Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)

We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
pull/3141/head
penguin_wwy 2024-04-11 21:47:35 +08:00 committed by GitHub
parent 308c45e61a
commit d4a30b7e67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 405 additions and 409 deletions

View File

@ -178,7 +178,7 @@ struct OpBinder {
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
auto integerAttr = element.dyn_cast<IntegerAttr>();
auto integerAttr = dyn_cast<IntegerAttr>(element);
if (!integerAttr)
return failure();
IntegerType t = cast<IntegerType>(integerAttr.getType());
@ -200,7 +200,7 @@ struct OpBinder {
return success();
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
StringAttr stringAttr = element.dyn_cast<StringAttr>();
StringAttr stringAttr = dyn_cast<StringAttr>(element);
if (!stringAttr)
return failure();
values.push_back(stringAttr.getValue().str());

View File

@ -94,7 +94,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {

View File

@ -1287,7 +1287,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), axisScalar, finalOffset);
Torch::BaseTensorType resultTensorType =
resultType.cast<Torch::BaseTensorType>();
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
@ -1899,7 +1899,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// If its a dense resource attr we need to convert to a dense type:
if (DenseResourceElementsAttr rattr =
attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
dyn_cast_or_null<DenseResourceElementsAttr>(attr)) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
@ -1916,7 +1916,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Attribute splattr;
if (isa<SplatElementsAttr>(attr)) {
auto denseAttr = attr.cast<DenseElementsAttr>();
auto denseAttr = cast<DenseElementsAttr>(attr);
splattr = denseAttr.getSplatValue<Attribute>();
}

View File

@ -1366,7 +1366,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// set the splitted axis to variable shape
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) {
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
}
@ -1437,7 +1437,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) {
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
}

View File

@ -272,9 +272,9 @@ public:
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
Value operandB =
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
if (resultType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
} else if (resultType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
} else {
return rewriter.notifyMatchFailure(

View File

@ -1881,7 +1881,7 @@ public:
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = getElementTypeOrSelf(input.getType());
if (!inputElementType.isa<ComplexType>()) {
if (!isa<ComplexType>(inputElementType)) {
return op.emitError("only ComplexType is allowed as input type");
}
Type elementType = resultType.getElementType();

View File

@ -131,7 +131,7 @@ public:
auto resultTy = op.getType().cast<ValueTensorType>();
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
Type elementType = cast<TensorType>(newResultType).getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
if (accumulatorDType != resultDTy) {
elementType = accumulatorDType;
@ -201,7 +201,7 @@ public:
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType);
}
@ -307,7 +307,7 @@ public:
unsigned rhsRank = rhsType.getRank();
Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = newResultType.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(newResultType);
Type elementType = resultType.getElementType();
// The different cases of torch_matmul op is mentioned here:
@ -600,9 +600,9 @@ public:
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
Type lhsElementType = cast<RankedTensorType>(lhsType).getElementType();
Type rhsElementType = cast<RankedTensorType>(rhsType).getElementType();
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
return rewriter.notifyMatchFailure(
@ -712,9 +712,9 @@ public:
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
if (!inputDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
!weightDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
!resultDTy.isa<mlir::FloatType, mlir::IntegerType>())
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
return op.emitError("unimplemented: non-fp not-int type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
size_t numSpatialDims = inRank - 2;
@ -790,9 +790,8 @@ public:
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
if (transposed) {
if (!inputDTy.isa<mlir::FloatType>() ||
!weightDTy.isa<mlir::FloatType>() ||
!resultDTy.isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
!isa<mlir::FloatType>(resultDTy))
return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet");
@ -927,10 +926,10 @@ public:
accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) {
Value c0;
if (accumulatorDType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(accumulatorDType, 0.0));
} else if (accumulatorDType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(accumulatorDType, 0));
}
@ -1021,7 +1020,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@ -1081,7 +1080,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@ -1125,7 +1124,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@ -1203,7 +1202,7 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

View File

@ -154,7 +154,7 @@ static LogicalResult createPoolingOp(
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type");
Value initValue =
@ -217,7 +217,7 @@ private:
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
Value initValue =
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
@ -335,7 +335,7 @@ public:
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
@ -416,7 +416,7 @@ public:
// `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
Value maxPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape;
@ -555,7 +555,7 @@ public:
self.getType().cast<RankedTensorType>().getElementType();
Type resultType = typeConverter->convertType(op.getType());
Type resultElementType =
resultType.cast<RankedTensorType>().getElementType();
cast<RankedTensorType>(resultType).getElementType();
bool ceilMode;
SmallVector<Value, Dim> kernelSizeIntValues;
@ -615,9 +615,9 @@ public:
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value avg;
if (resultElementType.isa<mlir::IntegerType>())
if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
else if (resultElementType.isa<mlir::FloatType>())
else if (isa<mlir::FloatType>(resultElementType))
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
b.create<linalg::YieldOp>(loc, avg);
})
@ -707,7 +707,7 @@ public:
Type auxTensorElementType = auxTensorType.getElementType();
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
smallestFPValueAttr);

View File

@ -130,7 +130,7 @@ public:
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType();
if (!elemTy.isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type");
if (!generator.getType().isa<Torch::NoneType>())

View File

@ -70,7 +70,7 @@ public:
input.getType().template cast<RankedTensorType>();
Type idxElementType =
getElementTypeOrSelf(typec->convertType(idxResultType));
if (!idxElementType.isa<IntegerType>())
if (!isa<IntegerType>(idxElementType))
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires integer-like result type");
@ -89,8 +89,8 @@ public:
Type inElementType = inputType.getElementType();
bool isUnsigned = false;
if (!inElementType.isa<mlir::FloatType>()) {
if (inElementType.isa<mlir::IntegerType>()) {
if (!isa<mlir::FloatType>(inElementType)) {
if (isa<mlir::IntegerType>(inElementType)) {
auto integerTy = op.getSelf()
.getType()
.template cast<BaseTensorType>()
@ -121,22 +121,21 @@ public:
loc, getAsOpFoldResult(resultShape), inElementType);
Value fillValue;
if (inElementType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(inElementType)) {
fillValue = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inElementType,
APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/isMax)));
loc, rewriter.getFloatAttr(
inElementType,
APFloat::getInf(
cast<mlir::FloatType>(inElementType).getFloatSemantics(),
/*Negative=*/isMax)));
} else if (!isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
auto init = isMax ? APSInt::getSignedMinValue(width)
: APSInt::getSignedMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init));
} else if (isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init));
@ -180,7 +179,7 @@ public:
rewriter.create<linalg::IndexOp>(loc, dim));
Value resultVal, predicate;
if (inElementType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(inElementType)) {
arith::CmpFPredicate predType;
if (isMax) {
predType = arith::CmpFPredicate::OGT;
@ -300,21 +299,21 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
if (isa<AtenProdDimIntOp>(op)) {
if (elementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0));
else if (elementType.isa<mlir::IntegerType>())
else if (isa<mlir::IntegerType>(elementType))
return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1));
}
if (isa<AtenMaxOp>(op)) {
if (elementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>(
loc, b.getFloatAttr(
elementType,
APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true)));
else if (elementType.isa<mlir::IntegerType>() &&
else if (isa<mlir::IntegerType>(elementType) &&
elementType.getIntOrFloatBitWidth() != 8)
return b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(elementType,
@ -323,14 +322,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
}
if (isa<AtenMinOp>(op)) {
if (elementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(elementType))
return b.create<arith::ConstantOp>(
loc, b.getFloatAttr(
elementType,
APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/false)));
else if (elementType.isa<mlir::IntegerType>() &&
else if (isa<mlir::IntegerType>(elementType) &&
elementType.getIntOrFloatBitWidth() != 8)
return b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(elementType,
@ -359,25 +358,25 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::AddFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>())
else if (isa<mlir::IntegerType>(resultElementType))
return b.create<arith::AddIOp>(loc, self, result);
} else if (isa<AtenProdDimIntOp>(op)) {
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MulFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>())
else if (isa<mlir::IntegerType>(resultElementType))
return b.create<arith::MulIOp>(loc, self, result);
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MaximumFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) {
else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = max.getSelf()
.getType()
.cast<BaseTensorType>()
@ -392,9 +391,9 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
if (resultElementType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MinimumFOp>(loc, self, result);
else if (resultElementType.isa<mlir::IntegerType>()) {
else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = min.getSelf()
.getType()
.cast<BaseTensorType>()
@ -626,10 +625,10 @@ private:
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>())
!isa<mlir::FloatType>(elemType))
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
if (isa<AtenAllDimOp>(op) && isa<mlir::IntegerType>(elemType) &&
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");

View File

@ -100,7 +100,7 @@ public:
}
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
Type elementType = cast<RankedTensorType>(newResultType).getElementType();
Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
@ -553,7 +553,7 @@ public:
// The size of the result is calculated as follows:
// ceil((end - start)/step)
Value resultShape;
if (dtype.isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(dtype)) {
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
} else {
@ -585,7 +585,7 @@ public:
index = castIndexToInt64(b, loc, index);
index = convertScalarToDtype(b, loc, index, dtype);
Value mulOut, result;
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
mulOut = b.create<arith::MulFOp>(loc, step, index);
result = b.create<arith::AddFOp>(loc, start, mulOut);
} else {

View File

@ -35,16 +35,16 @@ using namespace mlir::torch::Torch;
template <typename elementType> static bool hasElementType(Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();
Type tensorElementType = tensorType.getElementType();
return tensorElementType.isa<elementType>();
return isa<elementType>(tensorElementType);
}
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
arith::CmpIPredicate ispred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
Value lhs, Value rhs) {
if (type.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(type))
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
if (IntegerType intType = dyn_cast<mlir::IntegerType>(type)) {
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
if (intType.isSigned())
@ -319,7 +319,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseAndScalar.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseAndScalar.emitError(
"bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr;
@ -371,7 +371,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseRightShiftTensor.emitError(
"Bitwise_Right_Shift op does not support non-integer input dtype.");
return nullptr;
@ -385,7 +385,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseLeftShiftTensor.emitError(
"Bitwise_Left_Shift op does not support non-integer input dtype.");
return nullptr;
@ -623,7 +623,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled);
} else {
@ -647,7 +647,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType,
/*originalScalar=*/sub.getAlpha());
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
} else {
@ -664,10 +664,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype);
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::SubFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::SubIOp>(loc, self, mult);
}
@ -690,10 +690,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::AddFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::AddIOp>(loc, self, mult);
}
@ -708,9 +708,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if (dtype.isa<mlir::ComplexType>()) {
} else if (isa<mlir::ComplexType>(dtype)) {
return b.create<complex::MulOp>(loc, lhs, rhs);
} else {
return b.create<arith::MulIOp>(loc, lhs, rhs);
@ -720,7 +720,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(atan2.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
atan2.emitError("Atan2 requires floating point result type");
return nullptr;
}
@ -759,9 +759,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::DivFOp>(loc, lhs, rhs);
else if (dtype.isa<mlir::IntegerType>()) {
else if (isa<mlir::IntegerType>(dtype)) {
if (dtype.isUnsignedInteger())
return b.create<arith::DivUIOp>(loc, lhs, rhs);
return b.create<arith::DivSIOp>(loc, lhs, rhs);
@ -777,7 +777,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value div;
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
div = b.create<arith::DivFOp>(loc, lhs, rhs);
else {
if (dtype.isUnsignedInteger())
@ -798,7 +798,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
@ -811,7 +811,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
@ -831,7 +831,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -857,7 +857,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(pow.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -870,7 +870,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(imag.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
imag.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -882,7 +882,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(real.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
real.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -898,10 +898,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
gtScalar.emitError(
@ -928,10 +928,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
geScalar.emitError(
@ -955,7 +955,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
eqScalar.emitError(
@ -971,7 +971,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
neScalar.emitError(
@ -989,10 +989,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored.
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
ltScalar.emitError(
@ -1017,10 +1017,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
// that can be refactored.
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
leScalar.emitError(
@ -1096,14 +1096,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
clamp.emitError("unimplement type for clamp");
return nullptr;
}
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
isUnsigned = intTy.isUnsigned();
}
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
@ -1112,11 +1112,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*dstOriginalDtype=*/dstOriginalDtype);
Value pred;
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
auto cmp =
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
auto cmp =
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
if (getMax)
@ -1151,10 +1151,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
isMinNone = false;
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value pred;
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result,
minPromoted);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result,
minPromoted);
} else {
@ -1169,10 +1169,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
Value pred;
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result,
maxPromoted);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result,
maxPromoted);
} else {
@ -1194,10 +1194,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype);
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
return b.create<arith::SubIOp>(loc, other, mult);
}
@ -1211,9 +1211,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
return b.create<arith::MulFOp>(loc, lhs, rhs);
if (dtype.isa<mlir::IntegerType>())
if (isa<mlir::IntegerType>(dtype))
return b.create<arith::MulIOp>(loc, lhs, rhs);
mulScalar.emitError("unimplemented: Only integer/float dtype supported");
return nullptr;
@ -1246,7 +1246,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(divScalar.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(dtype)) {
divScalar.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -1263,9 +1263,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (newResultType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remScalar.emitError(
@ -1283,9 +1283,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (newResultType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remTensor.emitError(
@ -1303,12 +1303,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(newResultType)) {
Value n = b.create<arith::DivFOp>(loc, self, other);
n = b.create<math::TruncOp>(loc, n);
Value n_y = b.create<arith::MulFOp>(loc, n, other);
result = b.create<arith::SubFOp>(loc, self, n_y);
} else if (newResultType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(newResultType)) {
Value n = b.create<arith::DivSIOp>(loc, self, other);
Value n_y = b.create<arith::MulIOp>(loc, n, other);
result = b.create<arith::SubIOp>(loc, self, n_y);
@ -1349,7 +1349,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
Value predicate;
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold);
else
@ -1372,7 +1372,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value predicate;
if (dtype.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold);
else
@ -1426,7 +1426,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type elementType = converter->convertType(bitwiseNot.getType())
.cast<RankedTensorType>()
.getElementType();
if (elementType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementType)) {
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
return nullptr;
}
@ -2253,7 +2253,7 @@ public:
auto inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = inputType.getElementType();
if (!inputElementType.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(inputElementType)) {
op.emitError("Logit does not support non-floating point type");
return failure();
}

View File

@ -554,7 +554,7 @@ FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
}
Type type = *maybeType;
// The linalg-on-tensors backend currently expects integers to be signless.
if (auto intType = type.dyn_cast<IntegerType>()) {
if (auto intType = dyn_cast<IntegerType>(type)) {
type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
}
return type;

View File

@ -140,11 +140,11 @@ public:
// If the target type is non-torch type, then use TypeConverter to convert
// the type of the source.
if (targetType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(targetType)) {
targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(targetType)) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext());
@ -179,7 +179,7 @@ public:
// If the argument is a torch tensor, directly add it in the list of
// iter args.
if (torchType.isa<Torch::BaseTensorType>()) {
if (isa<Torch::BaseTensorType>(torchType)) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
@ -262,11 +262,11 @@ public:
// If the target type is non-torch type, then use TypeConverter to convert
// the type of the source.
if (targetType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(targetType)) {
targetType = Torch::FloatType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(targetType)) {
unsigned bitWidth = targetType.getIntOrFloatBitWidth();
if (bitWidth == 1)
targetType = Torch::BoolType::get(op->getContext());

View File

@ -42,11 +42,11 @@ static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (ty.isa<mlir::IntegerType>())
if (isa<mlir::IntegerType>(ty))
return b.getIntegerAttr(ty, constant);
if (ty.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(ty))
return b.getFloatAttr(ty, constant);
if (auto complexTy = ty.dyn_cast<mlir::ComplexType>())
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
@ -105,17 +105,17 @@ bool skipMultiplyAlpha(Value alphaValue) {
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementType);
if (elementType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementType)) {
auto constAttr = SplatElementsAttr::get(
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*negative=*/false));
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
auto integerType = elementType.cast<mlir::IntegerType>();
if (isa<mlir::IntegerType>(elementType)) {
auto integerType = cast<mlir::IntegerType>(elementType);
DenseElementsAttr constAttr;
if (integerType.isUnsigned()) {
constAttr = SplatElementsAttr::get(
@ -134,17 +134,17 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementType);
if (elementType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementType)) {
auto constAttr = SplatElementsAttr::get(
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*negative=*/true));
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
auto integerType = elementType.cast<mlir::IntegerType>();
if (isa<mlir::IntegerType>(elementType)) {
auto integerType = cast<mlir::IntegerType>(elementType);
DenseElementsAttr constAttr;
if (integerType.isUnsigned()) {
constAttr = SplatElementsAttr::get(
@ -446,7 +446,7 @@ public:
op, "only support constant str rounding mode");
// if trunc and int, do nothing
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) {
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
@ -457,7 +457,7 @@ public:
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (outElemTy.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(outElemTy))
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
else if (!outElemTy.isUnsignedInteger()) {
TensorType defaultIntToFloatType =
@ -518,10 +518,10 @@ public:
chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr;
if (lhsElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(lhsElemTy)) {
compareTypeAttr = chlo::ComparisonTypeAttr::get(
op->getContext(), chlo::ComparisonType::FLOAT);
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(lhsElemTy)) {
compareTypeAttr = chlo::ComparisonTypeAttr::get(
op->getContext(), chlo::ComparisonType::SIGNED);
}
@ -985,14 +985,14 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
auto lhsTy = lhs.getType().cast<RankedTensorType>();
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(lhsElemTy)) {
return op->emitError("only float tensor in relu op is supported");
}
Value zeroTensor;
zeroTensor = getConstantLike(
rewriter, op->getLoc(),
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
@ -1160,7 +1160,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(),
batchNormTrainingResult.getResult(0),
outputTy.cast<TensorType>());
cast<TensorType>(outputTy));
} else {
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
@ -1204,7 +1204,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
runningVar, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
outputTy.cast<TensorType>());
cast<TensorType>(outputTy));
} else {
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
->convertType(op.getType())
.cast<RankedTensorType>();
auto dtype = outType.getElementType();
if (!dtype.isa<mlir::IntegerType>() && !dtype.isa<mlir::FloatType>()) {
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only int or float dtype supported");
}
@ -1607,7 +1607,7 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getI64TensorAttr(elements));
auto outTy = getTypeConverter()->convertType(op.getType());
auto outElemTy = outTy.cast<RankedTensorType>().getElementType();
auto outElemTy = cast<RankedTensorType>(outTy).getElementType();
Value from =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
Value to =

View File

@ -34,14 +34,14 @@ static Value createInitialValueForGatherScatterOp(Operation *op,
PatternRewriter &rewriter) {
auto elementTy = constType.getElementType();
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});

View File

@ -37,14 +37,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
// Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
@ -55,14 +55,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
// Max pooling
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
constType,
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,

View File

@ -37,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy);
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
AtenLinalgVectorNormOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
@ -54,14 +54,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
constType,
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
@ -72,14 +72,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
if (isa<AtenMinOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getInf(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)});
constType,
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
} else if (isa<mlir::IntegerType>(elementTy) &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
@ -234,7 +234,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported!
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -305,7 +305,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
"Only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -319,7 +319,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<mlir::IntegerType>()) {
if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result");
}
@ -404,7 +404,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -466,7 +466,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -529,7 +529,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -603,7 +603,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
// Currently, (u)int8 dtype is not supported
if (inputElemTy.isa<mlir::IntegerType>() &&
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
@ -715,7 +715,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
auto inputRank = inputType.getRank();
auto inputElemType = inputType.getElementType();
if (!inputElemType.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(inputElemType)) {
return op.emitError(
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
}
@ -830,7 +830,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
auto outElemType = outType.getElementType();
if (!outElemType.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(outElemType)) {
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
}
@ -912,7 +912,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
op->getLoc(), blockArgumentTy,
DenseElementsAttr::get(
blockArgumentTy,
APFloat(outElemType.cast<mlir::FloatType>().getFloatSemantics(), 1)));
APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
op->getLoc(), blockArgumentTy, constantOne, ord);
auto output = rewriter.create<chlo::BroadcastPowOp>(

View File

@ -144,12 +144,12 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant");
if (dtype.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(dtype)) {
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue),
dshape, dtype)
.value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth();
if (w != 1 && w != 32 && w != 64)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
@ -261,7 +261,7 @@ public:
}
Type rhsAlphaMulElemType;
if (outElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(outElemTy)) {
rhsAlphaMulElemType = outElemTy;
} else {
// if output type is 64, input type should also be 32
@ -355,7 +355,7 @@ public:
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
if (lhsElemTy.isa<mlir::FloatType>() && isBitwiseOp) {
if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
return rewriter.notifyMatchFailure(op,
"For bitwise operators, only integer "
"datatype legalization is supported");
@ -442,8 +442,7 @@ public:
rhsTensor = rhsType ? rhs : rhsAsTensor;
}
if (outElemTy.isa<mlir::FloatType>() ||
outElemTy.isa<mlir::IntegerType>()) {
if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
@ -1454,7 +1453,7 @@ public:
SmallVector<int64_t> matmulOutputShape(
{matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
Type outputElemTy;
if (lhsElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(lhsElemTy)) {
outputElemTy = lhsElemTy;
} else { // qint8 emits i32 matmul output
outputElemTy = rewriter.getIntegerType(32);
@ -1898,7 +1897,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
// define a 48-bit int.
if (inputElemTy.isa<quant::QuantizedType>()) {
if (isa<quant::QuantizedType>(inputElemTy)) {
SmallVector<int32_t> zeroVec(weightShape[0], 0);
bias = tosa::getConstTensor<int32_t>(
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
@ -1915,7 +1914,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
op, "Bias provided but not a ranked tensor");
}
auto biasElemTy =
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();
isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter.getI32Type();
int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
@ -2098,7 +2097,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.getResult();
Value rescaledResult = transposedOutput;
if (inputElemTy.isa<quant::QuantizedType>()) {
if (isa<quant::QuantizedType>(inputElemTy)) {
rescaledResult = tosa::buildRescaleOpConvOutput(
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
}
@ -2230,7 +2229,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
if (toBcastType.getRank() > 1)
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
RankedTensorType outTensorType = outType.cast<RankedTensorType>();
RankedTensorType outTensorType = cast<RankedTensorType>(outType);
SmallVector<int64_t> newShape = {
makeShapeTorchCompatible(toBcastType.getShape())[0]};
for (auto i = 2; i < outTensorType.getRank(); ++i)
@ -2677,7 +2676,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
op, "Only floating-point or integer datatype legalization supported");
// Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported");
@ -2956,7 +2955,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
op, "Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(selfElemTy)) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}
@ -2993,7 +2992,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
op, "Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(selfElemTy)) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}
@ -3057,7 +3056,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
}
// Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported");
@ -4235,7 +4234,7 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
"Unable to extract the scalar constant");
auto outElemTy = resultType.getElementType();
if (outElemTy.isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(outElemTy)) {
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
op, resultType, DenseElementsAttr::get(resultType, {intValue}));
} else if (outElemTy.isF64()) {
@ -4383,7 +4382,7 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
auto divTensor = self;
// tosa::DivOp only supports int
if (outElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(

View File

@ -119,7 +119,7 @@ tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Value lhs, Value rhs) {
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
(void)rewriter.notifyMatchFailure(op,
"tosa.div only supports integer type");
}
@ -213,7 +213,7 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
Type outType, Value paramsValue,
Value indicesValue) {
auto resultType = outType.dyn_cast<ShapedType>();
auto resultType = dyn_cast<ShapedType>(outType);
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
@ -419,7 +419,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
Operation *op, Type outType,
Value paramsValue, Value indicesValue,
Value fillValues) {
auto resultType = outType.dyn_cast<ShapedType>();
auto resultType = dyn_cast<ShapedType>(outType);
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
@ -981,7 +981,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
Type elemType = output_type.getElementType();
if (!elemType.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(elemType)) {
op->emitOpError("Only floating-point datatype legalization supported for "
"AtenLinalgVectorNorm op");
return std::nullopt;

View File

@ -154,7 +154,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type) {
RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
RankedTensorType resultType = dyn_cast<RankedTensorType>(type);
if (!resultType) {
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
@ -167,7 +167,7 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
zeroAttr.cast<ElementsAttr>())
cast<ElementsAttr>(zeroAttr))
.getResult();
}
@ -312,7 +312,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) {
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
return rewriter.notifyMatchFailure(
@ -392,7 +392,7 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
// Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
// FP16 is supported, the accumulator type can be selected based on trade-off
// between performance and accuracy. Set to FP32 by default.
accType = inputETy.isa<FloatType>()
accType = isa<FloatType>(inputETy)
? mlir::TypeAttr::get(rewriter.getF32Type())
: mlir::TypeAttr::get(rewriter.getIntegerType(32));

View File

@ -27,9 +27,9 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
// TODO: Remove this check but use a separate verification pass to verify the
// invariants expected by later passes.
auto isValidLinalgType = [](Type type) {
if (type.isa<NonValueTensorType>())
if (isa<NonValueTensorType>(type))
return false;
auto tensor = type.dyn_cast<ValueTensorType>();
auto tensor = dyn_cast<ValueTensorType>(type);
return !tensor ||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
};
@ -43,8 +43,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
Type type = v.getType();
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
type.isa<mlir::NoneType>())
if (isa<OptionalType>(type) || isa<Torch::NoneType>(type) ||
isa<mlir::NoneType>(type))
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
return success();
}
@ -104,7 +104,7 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
Type lhsType = lhsDim.getType();
Type rhsType = rhsDim.getType();
auto checkIntOrIndex = [](Type type) {
assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
assert((isa<IntegerType>(type) || isa<IndexType>(type)) &&
"must be either integer or index type");
};
checkIntOrIndex(lhsType);
@ -198,13 +198,13 @@ Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
// Creates a constant of type `elemType` with value `val`.
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
TypedAttr attr = {};
if (elemType.isa<mlir::FloatType>())
if (isa<mlir::FloatType>(elemType))
attr = b.getFloatAttr(elemType, val);
if (elemType.isa<mlir::IndexType>())
if (isa<mlir::IndexType>(elemType))
attr = b.getIndexAttr(val);
if (elemType.isa<mlir::IntegerType>())
attr = b.getIntegerAttr(
elemType, APInt(elemType.cast<IntegerType>().getWidth(), val));
if (isa<mlir::IntegerType>(elemType))
attr = b.getIntegerAttr(elemType,
APInt(cast<IntegerType>(elemType).getWidth(), val));
if (!attr)
return nullptr;
return b.create<arith::ConstantOp>(loc, elemType, attr);
@ -264,7 +264,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return scalar;
auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
if (auto integerTy = dyn_cast<mlir::IntegerType>(type)) {
return integerTy.getWidth() == 8;
}
return false;
@ -303,10 +303,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
if (dtype.isSignlessInteger(1)) {
Type scalarType = scalar.getType();
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
if (scalarType.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(scalarType)) {
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
cstZero);
} else if (scalarType.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(scalarType)) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
cstZero);
} else {
@ -317,14 +317,14 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
}
}
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, dtype, scalar);
// Only scalarFloat width < dtypeFloat width can reach here.
return b.create<arith::ExtFOp>(loc, dtype, scalar);
}
assert(scalarType.isa<mlir::IntegerType>());
assert(isa<mlir::IntegerType>(scalarType));
if (scalarType.isSignlessInteger(1) ||
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
@ -333,11 +333,11 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
}
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
if (auto dtypeInteger = dyn_cast<mlir::IntegerType>(dtype)) {
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType))
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
assert(scalarType.isa<mlir::IntegerType>());
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
assert(isa<mlir::IntegerType>(scalarType));
auto scalarInteger = cast<mlir::IntegerType>(scalarType);
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, dtype, scalar);
if (scalarType.isSignlessInteger(1) ||

View File

@ -49,7 +49,7 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
size_t resultIndex = en.index();
Type resultType = en.value();
auto tensorType = resultType.dyn_cast<RankedTensorType>();
auto tensorType = dyn_cast<RankedTensorType>(resultType);
if (tensorType == nullptr) {
tmtensorOp.emitOpError()
<< "tensor to buffer conversion expects ranked tensor results";

View File

@ -100,10 +100,12 @@ void TorchDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
>();
addInterfaces<TorchInlinerInterface>();
}
@ -144,35 +146,34 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>())
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
if (auto integerType = dyn_cast<Torch::IntType>(type))
return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
if (auto floatType = type.dyn_cast<Torch::FloatType>())
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
if (auto floatType = dyn_cast<Torch::FloatType>(type))
return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
if (auto numberType = dyn_cast<Torch::NumberType>(type)) {
if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
} else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
}
}
if (type.isa<Torch::BoolType>()) {
return builder.create<Torch::ConstantBoolOp>(loc,
value.cast<IntegerAttr>());
if (isa<Torch::BoolType>(type)) {
return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
}
if (type.isa<Torch::NoneType>())
if (isa<Torch::NoneType>(type))
return builder.create<ConstantNoneOp>(loc);
if (auto stringAttr = value.dyn_cast<StringAttr>())
if (auto stringAttr = dyn_cast<StringAttr>(value))
return builder.create<ConstantStrOp>(loc, stringAttr);
if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) {
// Only !torch.vtensor can be constant folded. !torch.tensor has
// non-trivial aliasing semantics which prevent deduplicating it.
assert(type.isa<ValueTensorType>() && "should be a vtensor type!");
assert(isa<ValueTensorType>(type) && "should be a vtensor type!");
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
}

View File

@ -41,9 +41,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
return value;
// If the type is a tensor, then adjust the static information.
if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) ||
(type.isa<NonValueTensorType>() &&
desiredType.isa<NonValueTensorType>())) {
if ((isa<ValueTensorType>(type) && isa<ValueTensorType>(desiredType)) ||
(isa<NonValueTensorType>(type) && isa<NonValueTensorType>(desiredType))) {
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
desiredType, value);
return adjusted;
@ -90,7 +89,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
// then we do the copy by going to a value tensor and back.
if (tensor.getType().isa<NonValueTensorType>())
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
if (newType.isa<NonValueTensorType>())
if (isa<NonValueTensorType>(newType))
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
return tensor;
@ -132,11 +131,11 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
static Value getScalarIntValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
if (inputType.isa<Torch::IntType>()) {
if (isa<Torch::IntType>(inputType)) {
return input;
}
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
if (!inputTensorType)
return nullptr;
@ -166,11 +165,11 @@ static Value getScalarIntValue(Value input, Location loc,
static Value getScalarFloatValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
if (inputType.isa<Torch::FloatType>()) {
if (isa<Torch::FloatType>(inputType)) {
return input;
}
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
if (!inputTensorType)
return nullptr;
@ -273,7 +272,7 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
LogicalResult PrimListConstructOp::verify() {
auto resultType = getResult().getType();
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
auto resultElementType = dyn_cast<ListType>(resultType).getContainedType();
auto matchResultElementType = [&](Type type) {
return isValidSubtype(type, resultElementType);
};
@ -606,7 +605,7 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
Type rhsType = rhs.getType();
// If either type is a NoneType, make it be the lhsType.
if (rhsType.isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(rhsType)) {
std::swap(lhsType, rhsType);
std::swap(lhs, rhs);
}
@ -615,14 +614,14 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
// If both types are the singleton `!torch.none` type, then we don't even need
// to look at the values.
if (lhsType.isa<Torch::NoneType>() && rhsType.isa<Torch::NoneType>())
if (isa<Torch::NoneType>(lhsType) && isa<Torch::NoneType>(rhsType))
return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);
// If neither type is a subtype of the other, then the result is false.
// TODO: Implement and use subtype infra for this.
// For now, check a specific case.
// If the rhs is not OptionalType, then we know it cannot be None.
if (lhsType.isa<Torch::NoneType>() && !rhsType.isa<Torch::OptionalType>()) {
if (isa<Torch::NoneType>(lhsType) && !isa<Torch::OptionalType>(rhsType)) {
return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
!equalIsTrue);
}
@ -640,9 +639,9 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
auto step = adaptor.getStep();
if (!lo || !hi || !step)
return nullptr;
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
auto hiInt = hi.dyn_cast_or_null<IntegerAttr>().getValue();
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
auto loInt = dyn_cast_or_null<IntegerAttr>(lo).getValue();
auto hiInt = dyn_cast_or_null<IntegerAttr>(hi).getValue();
auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
// TODO: Implement folding for negative steps.
if (stepInt.isNegative())
return nullptr;
@ -650,7 +649,7 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
// r[i] = lo + step*i such that i >= 0 and r[i] < hi
// So maximize `i` such that lo + step * i < hi
// ==> i == ceildiv(hi - lo, step)
return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
return IntegerAttr::get(cast<TypedAttr>(lo).getType(),
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
APInt::Rounding::UP));
}
@ -665,10 +664,10 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
auto step = adaptor.getStep();
if (!index || !start || !step)
return nullptr;
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
return IntegerAttr::get(index.cast<TypedAttr>().getType(),
auto indexInt = dyn_cast_or_null<IntegerAttr>(index).getValue();
auto startInt = dyn_cast_or_null<IntegerAttr>(start).getValue();
auto stepInt = dyn_cast_or_null<IntegerAttr>(step).getValue();
return IntegerAttr::get(cast<TypedAttr>(index).getType(),
startInt + stepInt * indexInt);
}
@ -2768,9 +2767,9 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
Value constValue;
Attribute value = op.getValueAttr();
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
} else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
} else {
return failure();
@ -3192,9 +3191,9 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryFloatOperatorFn f) {
double lhs, rhs;
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
if (auto intLhs = dyn_cast_or_null<IntegerAttr>(attr)) {
value = static_cast<double>(intLhs.getValue().getSExtValue());
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
} else if (auto floatLhs = dyn_cast_or_null<FloatAttr>(attr)) {
value = floatLhs.getValue().convertToDouble();
} else {
return false;
@ -3945,7 +3944,7 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
@ -3966,11 +3965,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
if (isa<IntegerType>(elementType)) {
Attribute attribute = IntegerAttr::get(elementType, 1);
return DenseElementsAttr::get(shapedty, attribute);
}
if (elementType.isa<FloatType>()) {
if (isa<FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 1.0);
return DenseElementsAttr::get(shapedty, attribute);
}
@ -3984,7 +3983,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
@ -4006,11 +4005,11 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
}
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
if (isa<IntegerType>(elementType)) {
Attribute attribute = IntegerAttr::get(elementType, 0);
return DenseElementsAttr::get(shapedty, attribute);
}
if (elementType.isa<FloatType>()) {
if (isa<FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 0.0);
return DenseElementsAttr::get(shapedty, attribute);
}
@ -4025,7 +4024,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
@ -4043,14 +4042,14 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
if (isa<IntegerType>(elementType)) {
int64_t value = 0;
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
Attribute attribute = IntegerAttr::get(elementType, value);
return DenseElementsAttr::get(shapedty, attribute);
}
}
if (elementType.isa<FloatType>()) {
if (isa<FloatType>(elementType)) {
double value = 0.0;
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
Attribute attribute = FloatAttr::get(elementType, value);
@ -4631,15 +4630,14 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
for (Attribute symName : initialize.getSlotSymNames()) {
auto wasInserted = initializedGlobalSlots
.insert(symName.cast<FlatSymbolRefAttr>().getAttr())
.insert(cast<FlatSymbolRefAttr>(symName).getAttr())
.second;
if (!wasInserted)
return initialize.emitError("duplicate initialization of global slot: ")
<< symName;
}
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
return lhs.cast<StringAttr>().getValue() <
rhs.cast<StringAttr>().getValue();
return cast<StringAttr>(lhs).getValue() < cast<StringAttr>(rhs).getValue();
};
auto known = llvm::to_vector(knownGlobalSlots);
llvm::sort(known, lessThanByStringValue);
@ -4652,7 +4650,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
InFlightDiagnostic diag = initialize.emitOpError(
"must have one initializer for each global slot in the module");
for (auto knownGlobalSlot : known) {
auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>());
auto symName = FlatSymbolRefAttr::get(cast<StringAttr>(knownGlobalSlot));
if (!initializedGlobalSlots.count(knownGlobalSlot)) {
diag.attachNote(
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
@ -4663,7 +4661,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
if (!knownGlobalSlots.count(initializedGlobalSlot)) {
diag.attachNote().append(
"unexpected global slot initializer for non-existent global slot ",
FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>()));
FlatSymbolRefAttr::get(cast<StringAttr>(initializedGlobalSlot)));
}
}
return diag;

View File

@ -29,7 +29,7 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
// For a UnionType to be a subtype, all of its contained types must be
// subtypes.
if (auto unionType = subtype.dyn_cast<UnionType>()) {
if (auto unionType = dyn_cast<UnionType>(subtype)) {
for (auto containedType : unionType.getContainedTypes()) {
if (!isValidSubtype(containedType, type))
return false;
@ -37,17 +37,17 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true;
}
if (auto any = type.dyn_cast<AnyType>())
if (auto any = dyn_cast<AnyType>(type))
return true;
if (auto number = type.dyn_cast<NumberType>())
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
if (auto number = dyn_cast<NumberType>(type))
return isa<IntType>(subtype) || isa<Torch::FloatType>(subtype);
if (auto optional = type.dyn_cast<OptionalType>())
if (auto optional = dyn_cast<OptionalType>(type))
return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>();
isa<Torch::NoneType>(subtype);
if (auto unionType = type.dyn_cast<UnionType>()) {
if (auto unionType = dyn_cast<UnionType>(type)) {
for (auto containedType : unionType.getContainedTypes()) {
if (isValidSubtype(subtype, containedType))
return true;
@ -55,10 +55,10 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return false;
}
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (!subtype.isa<Torch::TupleType>())
if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
if (!isa<Torch::TupleType>(subtype))
return false;
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
auto subtypes = cast<Torch::TupleType>(subtype).getContainedTypes();
auto types = tuple.getContainedTypes();
if (subtypes.size() != types.size())
return false;
@ -69,14 +69,14 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true;
}
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
auto typeTensorType = type.dyn_cast<BaseTensorType>();
auto subtypeTensorType = dyn_cast<BaseTensorType>(subtype);
auto typeTensorType = dyn_cast<BaseTensorType>(type);
if (subtypeTensorType && typeTensorType) {
// Check that both tensors have the same `BaseTensorType` subtype.
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtypeTensorType.isa<ValueTensorType>() !=
typeTensorType.isa<ValueTensorType>())
if (isa<ValueTensorType>(subtypeTensorType) !=
isa<ValueTensorType>(typeTensorType))
return false;
// `type` must not have more static information than `subtype`, and `type`
@ -181,23 +181,23 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
static bool isValidTorchDtype(Type dtype) {
// For complex types, get the underlying element type
if (dtype.isa<ComplexType>()) {
dtype = dtype.cast<ComplexType>().getElementType();
if (isa<ComplexType>(dtype)) {
dtype = cast<ComplexType>(dtype).getElementType();
}
// Torch quantized types.
if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>())
if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>(dtype))
return true;
// Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
return true;
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
return true;
if (dtype.isa<Torch::StringType>())
if (isa<Torch::StringType>(dtype))
return true;
// Builtin integer types.
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (IntegerType type = dyn_cast<IntegerType>(dtype)) {
if (type.isSignless() && type.getWidth() == 1)
return true;
if (type.isSigned()) {
@ -273,7 +273,7 @@ verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
}
}
}
if (!optionalSparsity.isa<sparse_tensor::SparseTensorEncodingAttr>()) {
if (!isa<sparse_tensor::SparseTensorEncodingAttr>(optionalSparsity)) {
emitError() << "invalid sparsity encoding attribute";
return failure();
}
@ -441,12 +441,12 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
}
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
} else if (dtype.isa<mlir::ComplexType>()) {
} else if (isa<mlir::ComplexType>(dtype)) {
return dtype;
}
@ -502,8 +502,8 @@ void ValueTensorType::print(AsmPrinter &printer) const {
}
Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
assert(((lhs.isa<ValueTensorType>() && rhs.isa<ValueTensorType>()) ||
(lhs.isa<NonValueTensorType>() && rhs.isa<NonValueTensorType>())) &&
assert(((isa<ValueTensorType>(lhs) && isa<ValueTensorType>(rhs)) ||
(isa<NonValueTensorType>(lhs) && isa<NonValueTensorType>(rhs))) &&
"expected lhs and rhs to have same sense of value semantics");
// First, calculate the dtype.
@ -566,21 +566,21 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
// linkage) and the predicates themselves can't be added/used in the
// specification of the parameters of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) {
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>();
return isa<Torch::AnyType>(type) || isa<Torch::IntType>(type) ||
isa<Torch::BoolType>(type) || isa<Torch::FloatType>(type) ||
isa<Torch::StringType>(type) || isa<Torch::BaseTensorType>(type);
}
static bool isAnyTorchType(Type type) {
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() ||
type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() ||
type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() ||
type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() ||
type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() ||
type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() ||
type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() ||
type.isa<Torch::UnionType>();
isa<Torch::BaseTensorType>(type) || isa<Torch::AnyType>(type) ||
isa<Torch::BoolType>(type) || isa<Torch::DictType>(type) ||
isa<Torch::DeviceType>(type) || isa<Torch::GeneratorType>(type) ||
isa<Torch::ListType>(type) || isa<Torch::LinearParamsType>(type) ||
isa<Torch::NumberType>(type) || isa<Torch::NnModuleType>(type) ||
isa<Torch::NoneType>(type) || isa<Torch::OptionalType>(type) ||
isa<Torch::StringType>(type) || isa<Torch::TupleType>(type) ||
isa<Torch::UnionType>(type);
}
LogicalResult

View File

@ -53,7 +53,7 @@ public:
auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
if (!bound.isa<ValueTensorType>())
if (!isa<ValueTensorType>(bound))
return rewriter.notifyMatchFailure(
func, "unimplemented: preserving aliasing for non-value-semantic "
"type bounds");
@ -72,10 +72,10 @@ public:
SmallVector<Type> newResultTypes;
for (auto type : func.getFunctionType().getResults()) {
if (auto none = type.dyn_cast<Torch::NoneType>()) {
if (auto none = dyn_cast<Torch::NoneType>(type)) {
continue;
}
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (auto tuple = dyn_cast<Torch::TupleType>(type)) {
llvm::append_range(newResultTypes, tuple.getContainedTypes());
continue;
}
@ -133,12 +133,12 @@ public:
int newOpResultIdx = 0;
SmallVector<Value> newResults;
for (auto type : call.getResultTypes()) {
if (type.isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(type)) {
newResults.push_back(
rewriter.create<ConstantNoneOp>(call.getLoc(), type));
continue;
}
if (type.isa<Torch::TupleType>()) {
if (isa<Torch::TupleType>(type)) {
newResults.push_back(rewriter.create<PrimTupleConstructOp>(
call.getLoc(), type, newCall.getResults()));
continue;

View File

@ -1386,7 +1386,7 @@ static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
unNormalizedExp, sum);
if (resultType != accumulatorType)
result = convertTensorToDtype(rewriter, loc, result,
resultType.cast<BaseTensorType>().getDtype());
cast<BaseTensorType>(resultType).getDtype());
return result;
}
@ -1405,7 +1405,7 @@ public:
op, "expected result type to have a dtype");
}
Type resultTensorDtype = resultTensorType.getDtype();
if (!resultTensorDtype.isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(resultTensorDtype))
return rewriter.notifyMatchFailure(op,
"Only support floating-point type");
@ -1980,7 +1980,7 @@ public:
}
Type dtype = resType.getDtype();
if (dtype.isa<mlir::ComplexType>()) {
if (isa<mlir::ComplexType>(dtype)) {
return rewriter.notifyMatchFailure(
op, "lowering of aten.linalg_cross for complex inputs dtype is "
"currently unimplemented");
@ -2015,7 +2015,7 @@ public:
Value none = rewriter.create<ConstantNoneOp>(loc);
// idx = torch.arange(3)
auto outType = opType.dyn_cast<BaseTensorType>();
auto outType = dyn_cast<BaseTensorType>(opType);
auto arangeType = outType.getWithSizesAndDtype(
llvm::ArrayRef<int64_t>(3),
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
@ -5848,7 +5848,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Value keepDim = op.getKeepdim();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
if (!outputTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(op,
"expected result type to have a dtype");
@ -5893,7 +5893,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
rewriter, op, cast<BaseTensorType>(meanDimResultType),
dimListElements[i],
/*keepDim=*/true);
@ -6189,7 +6189,7 @@ public:
Location loc = op.getLoc();
Type resultType = op.getType();
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
BaseTensorType resultTensorType = cast<BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");

View File

@ -207,7 +207,7 @@ public:
return failure();
Type resultETy = resultTy.getDtype();
if (!resultETy.isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(resultETy))
return failure();
Value lhsScale;

View File

@ -183,13 +183,13 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
}
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = point.dyn_cast<Value>()) {
if (Value value = dyn_cast<Value>(point)) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));
// Handle GlobalSlotGetOp's.
if (auto opResult = value.dyn_cast<OpResult>()) {
if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
@ -205,7 +205,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
return success();
}
if (auto *genericProgramPoint = point.dyn_cast<GenericProgramPoint *>()) {
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) {
@ -396,7 +396,7 @@ class InlineGlobalSlotsPass
// This could be left to SymbolDCE but it's not hard to do here.
for (FlatSymbolRefAttr symName :
llvm::map_range(safeToInline, [](Attribute attr) {
return attr.cast<FlatSymbolRefAttr>();
return cast<FlatSymbolRefAttr>(attr);
})) {
auto globalSlot =
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());

View File

@ -46,14 +46,14 @@ static LogicalResult checkType(Operation *op, Type type,
// can statically pattern match and eliminate from the program.
// For example, a tensor operand might be optional, and the backend
// will pattern-match statically whether it is passed as a tensor or None.
if (type.isa<Torch::NoneType, Torch::StringType>())
if (isa<Torch::NoneType, Torch::StringType>(type))
return success();
// We blanket prohibit non-value-semantic tensors.
// All of our backends are currently based on value-semantic tensors, so
// we consider it our responsibility to lower all non-value-semantic tensors
// to value-semantic tensors.
if (type.isa<NonValueTensorType>()) {
if (isa<NonValueTensorType>(type)) {
if (actuallyEmitDiagnostics) {
return op
->emitError("unsupported by backend contract: non-value tensor type")
@ -84,7 +84,7 @@ static LogicalResult checkType(Operation *op, Type type,
// have an sufficiently rich system for representing PyTorch type promotion
// rules. So we consider it our responsibility to ensure that all dtypes are
// statically known.
if (auto tensorType = type.dyn_cast<ValueTensorType>()) {
if (auto tensorType = dyn_cast<ValueTensorType>(type)) {
if (!tensorType.hasSizes()) {
if (actuallyEmitDiagnostics) {
return op
@ -115,7 +115,7 @@ static LogicalResult checkType(Operation *op, Type type,
// Optional types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary.
if (auto optionalType = type.dyn_cast<OptionalType>()) {
if (auto optionalType = dyn_cast<OptionalType>(type)) {
// TODO: Be stricter about tensor types.
// See comment below for ListType.
if (optionalType.getContainedType().isa<ValueTensorType>())
@ -127,7 +127,7 @@ static LogicalResult checkType(Operation *op, Type type,
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary. For example, the
// strides of a convolution op are represented as a list.
if (auto listType = type.dyn_cast<ListType>()) {
if (auto listType = dyn_cast<ListType>(type)) {
// TODO: Be stricter about tensor types.
// For the moment, there are cases (such as for torch.cat) where we end
// up with `!torch.list<vtensor>` which doesn't have shape or dtype in
@ -141,7 +141,7 @@ static LogicalResult checkType(Operation *op, Type type,
// Tuple types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary.
if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
if (auto tupleType = dyn_cast<Torch::TupleType>(type)) {
for (auto containedType : tupleType.getContainedTypes()) {
if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
return failure();

View File

@ -140,7 +140,7 @@ public:
auto returnOp = ops.returnOp.value();
for (auto operand : llvm::enumerate(returnOp->getOperands())) {
auto type = operand.value().getType();
if (!type.isa<NonValueTensorType>())
if (!isa<NonValueTensorType>(type))
continue;
originalReturnTypes[operand.index()] = type;
}

View File

@ -38,15 +38,15 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
}
static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
if (auto optionalType = type.dyn_cast<OptionalType>()) {
if (auto optionalType = dyn_cast<OptionalType>(type)) {
Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
optionalType.getContainedType());
return OptionalType::get(newContainedType);
} else if (auto listType = type.dyn_cast<ListType>()) {
} else if (auto listType = dyn_cast<ListType>(type)) {
Type newContainedType =
getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
return ListType::get(newContainedType);
} else if (auto tensorType = type.dyn_cast<NonValueTensorType>()) {
} else if (auto tensorType = dyn_cast<NonValueTensorType>(type)) {
return tensorType.getWithValueSemantics();
} else {
return nullptr;
@ -92,10 +92,10 @@ public:
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
if (operandType.isa<NonValueTensorType>()) {
if (isa<NonValueTensorType>(operandType)) {
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
opOperand.get()));
} else if (auto listType = operandType.dyn_cast<ListType>()) {
} else if (auto listType = dyn_cast<ListType>(operandType)) {
if (!(listType.getContainedType().isa<NonValueTensorType>() ||
listType.getContainedType().isa<OptionalType>()))
continue;
@ -144,7 +144,7 @@ public:
}
opOperand.set(rewriter.create<PrimListConstructOp>(
op->getLoc(), newListType, newListElements));
} else if (auto optionalType = operandType.dyn_cast<OptionalType>()) {
} else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
// TODO: A more general way to handle the optional type is to
// introduce a `copy.to_optional_vtensor` op.
if (!optionalType.getContainedType().isa<NonValueTensorType>())
@ -450,7 +450,7 @@ struct ReduceOpVariantsPass
auto hasValueSemantics = [](Type t) {
// TODO: Make this an allowlist based on a closed torch dialect
// type system.
if (auto tensorType = t.dyn_cast<NonValueTensorType>()) {
if (auto tensorType = dyn_cast<NonValueTensorType>(t)) {
return false;
}
return true;

View File

@ -170,7 +170,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
if (operandType == desiredType)
return operand;
if (desiredType.isa<Torch::AnyType>()) {
if (isa<Torch::AnyType>(desiredType)) {
// Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need
@ -180,8 +180,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// The type `!torch.number` can be an `int`, `float`, or `complex`.
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
if (desiredType.isa<Torch::NumberType>() &&
operandType.isa<Torch::IntType, Torch::FloatType>()) {
if (isa<Torch::NumberType>(desiredType) &&
isa<Torch::IntType, Torch::FloatType>(operandType)) {
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
@ -189,7 +189,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// `Scalar` inputs. At compile time, such inputs will usually be
// resolved to an `int`, `float`, or `None` so we need to derefine
// to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
@ -200,8 +200,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// Operands with type `!torch.none` correspond to library function inputs with
// types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the
// type is derefined to match the expected type of the library function.
if (operandType.isa<Torch::NoneType>()) {
assert(!desiredType.isa<Torch::NoneType>() &&
if (isa<Torch::NoneType>(operandType)) {
assert(!isa<Torch::NoneType>(desiredType) &&
"Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
@ -211,8 +211,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// dtype of input scalars. However, this also means we sometimes have to
// manually turn `Scalar`s into `float`s when inserting the shape functions
// into the IR.
if (operandType.isa<Torch::NumberType>() &&
desiredType.isa<Torch::FloatType>()) {
if (isa<Torch::NumberType>(operandType) &&
isa<Torch::FloatType>(desiredType)) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
}
@ -224,8 +224,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// type).
// A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
if (auto operandOptionalType = dyn_cast<Torch::OptionalType>(operandType)) {
if (isa<Torch::OptionalType>(desiredType)) {
// if optional is None:
// return derefine(None)
// else:
@ -258,7 +258,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
if (auto desiredOptionalType = dyn_cast<Torch::OptionalType>(desiredType)) {
FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, operand, desiredOptionalType.getContainedType(),
baseTransformation);
@ -267,7 +267,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
}
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
if (auto desiredListType = dyn_cast<Torch::ListType>(desiredType)) {
// Pseudocode:
//
// operand = ...
@ -311,7 +311,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// The library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation).
if (desiredType.isa<Torch::FloatType>() &&
if (isa<Torch::FloatType>(desiredType) &&
operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
}

View File

@ -29,7 +29,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::TupleType>() &&
if (isa<Torch::TupleType>(desiredType) &&
operand.getType().isa<Torch::BaseTensorType>()) {
Type intType = Torch::IntType::get(b.getContext());
Type sizeListType = Torch::ListType::get(intType);

View File

@ -38,7 +38,7 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
Type desiredType) -> Value {
// The shape library functions have tensor operands replaced with
// `!torch.list<int>` types for the shape. Get the sizes.
auto desiredListType = desiredType.dyn_cast<Torch::ListType>();
auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
if (!desiredListType)
return operand;
if (operand.getType().isa<Torch::BaseTensorType>() &&

View File

@ -262,13 +262,13 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
originalResultType.template dyn_cast<BaseTensorType>()) {
// If we didn't get any new information, there is nothing left for us to do.
updatedType = meetTensorTypes(originalBaseTensorType,
newResultType.cast<BaseTensorType>());
cast<BaseTensorType>(newResultType));
if (!updatedType || updatedType == originalBaseTensorType)
return rewriter.notifyMatchFailure(
calculateOp, "New type information does not refine old type");
} else if (auto originalResultType =
result.getType().template dyn_cast<Torch::NumberType>()) {
if (!newResultType.isa<Torch::FloatType, Torch::IntType>()) {
if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
return rewriter.notifyMatchFailure(
calculateOp,
"Refinement of `NumberType` must be a `FloatType` or `IntType`");
@ -291,10 +291,10 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
}
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(calculateOp);
if (originalResultType.isa<BaseTensorType>()) {
if (isa<BaseTensorType>(originalResultType)) {
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
loc, originalResultType, result);
} else if (originalResultType.isa<Torch::NumberType>()) {
} else if (isa<Torch::NumberType>(originalResultType)) {
originalTypedValue =
rewriter.create<DerefineOp>(loc, originalResultType, result);
} else {
@ -314,14 +314,14 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
OpOperand &use = yieldValues->getOpOperand(resultNum);
Value def = use.get();
Value newYieldedValue;
if (def.isa<OpResult>() &&
def.cast<OpResult>()
if (isa<OpResult>(def) &&
cast<OpResult>(def)
.getDefiningOp()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
newYieldedValue = def;
} else {
rewriter.setInsertionPoint(yieldValues);
if (updatedType.isa<BaseTensorType>()) {
if (isa<BaseTensorType>(updatedType)) {
newYieldedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
} else {

View File

@ -53,8 +53,9 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
op, "Failed to convert `dtypeScalarType` to a builtin type");
}
impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(), *builtinType);
cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
*builtinType);
} else {
return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to "
@ -179,7 +180,7 @@ public:
}
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType =
originalResultType.cast<BaseTensorType>()
cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType)
.cast<BaseTensorType>();

View File

@ -98,7 +98,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
auto originalResultType = result.getType().cast<BaseTensorType>();
auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>()
cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype())
.cast<BaseTensorType>();

View File

@ -70,8 +70,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::QInt8;
if (type.isa<QInt32Type>())
return torch_upstream::ScalarType::QInt32;
if (type.isa<ComplexType>()) {
mlir::Type complexElemType = type.cast<ComplexType>().getElementType();
if (isa<ComplexType>(type)) {
mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
if (complexElemType.isF16())
return torch_upstream::ScalarType::ComplexHalf;
if (complexElemType.isF32())
@ -84,9 +84,9 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
Type Torch::getTypeForTorchType(
MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness) {
if (type.isa<Torch::IntType>())
if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, signedness);
if (type.isa<Torch::FloatType>())
if (isa<Torch::FloatType>(type))
return Float64Type::get(context);
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
}
@ -150,14 +150,14 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>()) {
if (isa<Torch::FloatType>(type)) {
// For now, use float32 which is the initial default dtype returned by
// `torch.get_default_dtype`.
return Float32Type::get(context);
}
if (type.isa<Torch::IntType>())
if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
if (isa<Torch::BoolType>(type))
return IntegerType::get(context, 1);
llvm_unreachable(
"getDefaultDtypeForTorchScalar called on an unsupported type");
@ -165,11 +165,11 @@ Type Torch::getDefaultDtypeForTorchScalar(Type type) {
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>())
if (isa<Torch::FloatType>(type))
return Float64Type::get(context);
if (type.isa<Torch::IntType>())
if (isa<Torch::IntType>(type))
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
if (isa<Torch::BoolType>(type))
return IntegerType::get(context, 1);
llvm_unreachable(
"getBuiltInTypeForTorchScalar called on an unsupported type");

View File

@ -62,15 +62,14 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
Attribute value,
Type type,
Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>())
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
if (auto integerType = dyn_cast<Torch::IntType>(type))
return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));
if (auto floatType = type.dyn_cast<Torch::FloatType>())
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
if (auto floatType = dyn_cast<Torch::FloatType>(type))
return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));
if (type.isa<Torch::BoolType>()) {
return builder.create<Torch::ConstantBoolOp>(loc,
value.cast<IntegerAttr>());
if (isa<Torch::BoolType>(type)) {
return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
}
return arith::ConstantOp::materialize(builder, value, type, loc);

View File

@ -95,7 +95,7 @@ public:
// get outputs
Type newResultType = getTypeConverter()->convertType(op.getType(0));
auto resultType = newResultType.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(newResultType);
if (!resultType) {
return failure();
}

View File

@ -33,7 +33,7 @@ class VerifyStablehloBackendContractPass
converter.addConversion([](Type type) -> Type {
auto elemTy = type;
if (isa<TensorType>(type))
elemTy = type.cast<TensorType>().getElementType();
elemTy = cast<TensorType>(type).getElementType();
if (BaseMemRefType::isValidElementType(elemTy))
return type;
return nullptr;

View File

@ -54,11 +54,11 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
//===----------------------------------------------------------------------===//
static bool isArgMemRefTypeValid(Type type) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
if (auto memRefType = dyn_cast<MemRefType>(type)) {
Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
return true;
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
if (integerTy.isSignlessInteger(64))
return true;
if (integerTy.isSignlessInteger(32))
@ -69,7 +69,7 @@ static bool isArgMemRefTypeValid(Type type) {
return true;
if (integerTy.isSignlessInteger(1))
return true;
} else if (auto complexTy = elemTy.dyn_cast<ComplexType>()) {
} else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
return complexTy.getElementType().isa<Float32Type, Float64Type>();
}
}
@ -81,7 +81,7 @@ static void addEmitCInterfaceAttr(func::FuncOp func) {
}
static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
return UnrankedMemRefType::get(cast<MemRefType>(type).getElementType(), 0);
}
// Helper function to get the type string for one return value like i32, f64,
@ -90,12 +90,12 @@ static Type getAbiTypeForMemRef(Type type) {
static std::string getTypeToken(Type type) {
if (type.isSignlessInteger())
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
else if (type.isa<mlir::FloatType>())
else if (isa<mlir::FloatType>(type))
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
else if (auto complexTy = type.dyn_cast<mlir::ComplexType>())
else if (auto complexTy = dyn_cast<mlir::ComplexType>(type))
return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth()))
.str();
else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
else if (auto memRefType = dyn_cast<UnrankedMemRefType>(type))
return "mr" + getTypeToken(memRefType.getElementType());
llvm_unreachable(
@ -171,7 +171,7 @@ static LogicalResult mungeFunction(
for (auto en : llvm::enumerate(types)) {
Type retType = en.value();
Value retVal = op.getOperand(en.index());
if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
if (auto memrefReturnType = dyn_cast<MemRefType>(retType)) {
auto elemType = memrefReturnType.getElementType();
retType = UnrankedMemRefType::get(elemType, 0);
// Cast to unranked memref type before sending it as a function