[NFC] Change to *cast instead of .*cast variants (#3405)

Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
pull/3407/head
Rob Suderman 2024-05-30 23:45:13 -07:00 committed by GitHub
parent 4e05e2cd1e
commit afca88a058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 551 additions and 658 deletions

View File

@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
# TODO(#3299): migrate from member x.cast<T>() to mlir::cast<T>(x).
if(MSVC)
add_compile_options(/wd4996)
else()
add_compile_options(-Wno-deprecated-declarations)
endif()
macro(torch_mlir_enable_werror)
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
if(NOT MSVC)

View File

@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
return result;
}]
@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
return result;
}]
@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
return result;
}]
@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
return result;
}]
@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputBufferOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<MemRefType>();
return cast<MemRefType>(opOperands->get().getType());
});
return result;
}]
@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputTensorOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<RankedTensorType>();
return cast<RankedTensorType>(opOperands->get().getType());
});
return result;
}]
@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
if (!isa<RankedTensorType>(opOperand->get().getType()))
return false;
if (opOperand->getOperandNumber() < $_op.getNumInputs())
return true;
@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
if (!isa<RankedTensorType>(opOperand->get().getType()))
return false;
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
return true;
@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getRank();
return 0;
}]
@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getShape();
return {};
}]
@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
return !isa<ShapedType>(opOperand->get().getType());
}]
>,
//===------------------------------------------------------------------===//
@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<MemRefType>();
isa<MemRefType>(opOperand->get().getType());
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
}]
>,
@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<RankedTensorType>();
isa<RankedTensorType>(opOperand->get().getType());
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
}]
>,
@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>();
auto attr = cast<DenseIntElementsAttr>((*this)->getAttr("operand_segment_sizes")
);
unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });

View File

@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
return getOutputOperand(0)->get();
}
ShapedType getOperandType() {
return input().getType().cast<ShapedType>();
return cast<ShapedType>(input().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
int64_t getIndexDepth() {
return getInputOperand(1)
return cast<ShapedType>(getInputOperand(1)
->get()
.getType()
.cast<ShapedType>()
)
.getShape()
.back();
}
@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
}
ShapedType getUpdateType() {
return updates().getType().cast<ShapedType>();
return cast<ShapedType>(updates().getType());
}
Value indices() {
@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
}
ShapedType getIndicesType() {
return indices().getType().cast<ShapedType>();
return cast<ShapedType>(indices().getType());
}
Value original() {
@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
}
ShapedType getOriginalType() {
return original().getType().cast<ShapedType>();
return cast<ShapedType>(original().getType());
}
int64_t getUpdateSliceRank() {
return updates().getType().cast<ShapedType>().getRank() - 1;
return cast<ShapedType>(updates().getType()).getRank() - 1;
}
bool isScalarUpdate() {
@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
return getOutputs()[index];
}
ShapedType getOperandType(int index) {
return operand(index).getType().cast<ShapedType>();
return cast<ShapedType>(operand(index).getType());
}
int64_t getOperandRank() {
return getOperandType(0).getRank();
@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
return getOutputOperand(0)->get();
}
ShapedType getQueryType() {
return getQuery().getType().cast<ShapedType>();
return cast<ShapedType>(getQuery().getType());
}
ShapedType getKeyType() {
return getKey().getType().cast<ShapedType>();
return cast<ShapedType>(getKey().getType());
}
ShapedType getValueType() {
return getValue().getType().cast<ShapedType>();
return cast<ShapedType>(getValue().getType());
}
ShapedType getOutputType() {
return getOutput().getType().cast<ShapedType>();
return cast<ShapedType>(getOutput().getType());
}
int64_t getQueryRank() {
return getQueryType().getRank();

View File

@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder {
bool match(Operation *op) {
auto constOp = dyn_cast<Torch::OperatorOp>(op);
if (!constOp || !constOp.getName().equals("onnx.Constant"))
if (!constOp || !(constOp.getName() == "onnx.Constant"))
return false;
if (DenseResourceElementsAttr attr =
constOp->getAttr("torch.onnx.value")
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
dyn_cast_or_null<DenseResourceElementsAttr>(
constOp->getAttr("torch.onnx.value"))) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {

View File

@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num);
else if (value.getType().isa<Torch::NoneType>())
else if (isa<Torch::NoneType>(value.getType()))
bind_values.push_back(std::nullopt);
else
return false;

View File

@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
}];
let extraClassDeclaration = [{
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
Type getValueType() { return getType().cast<DictType>().getValueType(); }
Type getKeyType() { return cast<DictType>(getType()).getKeyType(); }
Type getValueType() { return cast<DictType>(getType()).getValueType(); }
}];
}
@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.vtensor",
"result", "operand",
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
"cast<NonValueTensorType>($_self).getWithValueSemantics()">,
]> {
let summary = "Create a !torch.tensor with the same contents as the operand";
let description = [{
@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.tensor",
"result", "operand",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">,
]> {
let summary = "Create a !torch.vtensor with the same contents as the operand";
let description = [{
@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
"value", "overwritten",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">
]> {
let summary = "Overwrite the contents of tensor with values from another.";
let description = [{

View File

@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
}
def AnyTorchTensorType : Type<
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">,
"Any Torch tensor type"
>;
@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType:
def AnyTorchOptionalGeneratorType:
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">;
class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>,
IsListTypePred,
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
"cast<::mlir::torch::Torch::ListType>($_self).getContainedType()",
descr, "::mlir::torch::Torch::ListType">;
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;

View File

@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
return isa<Torch::NnModuleType>(unwrap(t));
}
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>();
return isa<Torch::OptionalType>(unwrap(t));
}
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>();
return isa<Torch::TupleType>(unwrap(t));
}
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchUnion(MlirType t) {
return unwrap(t).isa<Torch::UnionType>();
return isa<Torch::UnionType>(unwrap(t));
}
MlirType torchMlirTorchUnionTypeGet(MlirContext context,
@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchList(MlirType t) {
return unwrap(t).isa<Torch::ListType>();
return isa<Torch::ListType>(unwrap(t));
}
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchDevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>();
return isa<Torch::DeviceType>(unwrap(t));
}
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchGenerator(MlirType t) {
return unwrap(t).isa<Torch::GeneratorType>();
return isa<Torch::GeneratorType>(unwrap(t));
}
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchBool(MlirType t) {
return unwrap(t).isa<Torch::BoolType>();
return isa<Torch::BoolType>(unwrap(t));
}
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchInt(MlirType t) {
return unwrap(t).isa<Torch::IntType>();
return isa<Torch::IntType>(unwrap(t));
}
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchFloat(MlirType t) {
return unwrap(t).isa<Torch::FloatType>();
return isa<Torch::FloatType>(unwrap(t));
}
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>();
return isa<Torch::LinearParamsType>(unwrap(t));
}
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>();
return isa<Torch::QInt8Type>(unwrap(t));
}
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchQUInt8(MlirType t) {
return unwrap(t).isa<Torch::QUInt8Type>();
return isa<Torch::QUInt8Type>(unwrap(t));
}
MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
@ -274,7 +274,7 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>();
return isa<Torch::NonValueTensorType>(unwrap(t));
}
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
@ -341,7 +341,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>();
return isa<Torch::ValueTensorType>(unwrap(t));
}
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
@ -408,7 +408,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>();
return isa<Torch::NoneType>(unwrap(t));
}
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
@ -424,7 +424,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchString(MlirType t) {
return unwrap(t).isa<Torch::StringType>();
return isa<Torch::StringType>(unwrap(t));
}
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
@ -440,7 +440,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchAny(MlirType t) {
return unwrap(t).isa<Torch::AnyType>();
return isa<Torch::AnyType>(unwrap(t));
}
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
@ -456,7 +456,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNumber(MlirType t) {
return unwrap(t).isa<Torch::NumberType>();
return isa<Torch::NumberType>(unwrap(t));
}
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
@ -472,7 +472,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchDict(MlirType t) {
return unwrap(t).isa<Torch::DictType>();
return isa<Torch::DictType>(unwrap(t));
}
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {

View File

@ -546,12 +546,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value shuffledPaddingList =
createConstantIntList(binder, rewriter, padding);
Value zero;
if (resultTypeOut.getDtype().isa<FloatType>()) {
if (isa<FloatType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(
std::numeric_limits<double>::lowest()));
} else if (resultTypeOut.getDtype().isa<IntegerType>()) {
} else if (isa<IntegerType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
std::numeric_limits<int64_t>::lowest()));
@ -1295,7 +1295,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
@ -1509,10 +1509,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (!constantValue) {
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
if (dataTensorType.getDtype().isa<IntegerType>())
if (isa<IntegerType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
if (dataTensorType.getDtype().isa<FloatType>())
if (isa<FloatType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0f));

View File

@ -1023,9 +1023,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value constFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto size = data.getType()
.dyn_cast<Torch::ValueTensorType>()
.getOptionalSizes();
auto size =
dyn_cast<Torch::ValueTensorType>(data.getType()).getOptionalSizes();
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
size, rewriter.getF64Type());
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
@ -2906,8 +2905,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
if (scalesValueList.getType().isa<Torch::NoneType>() &&
sizesValueList.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter

View File

@ -1868,9 +1868,8 @@ public:
const TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.getSelf();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
@ -2107,9 +2106,8 @@ public:
auto input = adaptor.getSelf();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
@ -2343,9 +2341,8 @@ public:
op, "diagonal dimensions cannot be identical");
Type elementType = inputType.getElementType();
RankedTensorType outputType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType outputType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Location loc = op.getLoc();
Value dim1Size, dim2Size;
@ -2581,9 +2578,8 @@ public:
})
.getResult(0);
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
return success();
@ -2608,9 +2604,8 @@ public:
return failure();
// Conversion is completed specified by information in the sparse tensor
// type. Thus, we can rewrite all legalizedNames to the same construct.
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
op, resultType, adaptor.getOperands()[0]);
return success();

View File

@ -845,7 +845,7 @@ public:
outputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
// Convert float values to int values.
// int_value = (int64_t)ceil(float_value)
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH());
@ -858,7 +858,7 @@ public:
scaleFactorsInt.push_back(scaleFactorVal);
}
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
// Convert float values to int values.
// int_value = (int64_t)ceil(float_value)
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW());
@ -1006,7 +1006,7 @@ public:
unsigned hDimOffset = 2;
SmallVector<Value, 2> scaleFactorsFloatValues;
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
scaleFactorsFloatValues.push_back(adaptor.getScalesH());
} else {
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
@ -1019,7 +1019,7 @@ public:
scaleFactorsFloatValues.push_back(scaleFactorVal);
}
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
scaleFactorsFloatValues.push_back(adaptor.getScalesW());
} else {
auto scaleFactorVal = rewriter.create<arith::DivFOp>(

View File

@ -41,7 +41,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
@ -1057,10 +1057,10 @@ public:
loc, getAsOpFoldResult(outDims), accumulatorDType);
Value outputTensor;
if (accumulatorDType != resultDTy && !bias.getType().isa<Torch::NoneType>())
if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(bias.getType())) {
Value c0;
if (isa<mlir::FloatType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>(

View File

@ -409,10 +409,8 @@ public:
Value self = adaptor.getSelf();
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
Type elementType = selfType.getElementType();
RankedTensorType indicesRankedTensorType =
getTypeConverter()
->convertType(op->getResult(1).getType())
.cast<RankedTensorType>();
RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(1).getType()));
// TODO: Add support for 3D inputs.
if (selfType.getRank() == 3)
@ -717,10 +715,10 @@ public:
Location loc = op->getLoc();
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
outputType = typeConverter->convertType(op.getResult0().getType())
.template cast<RankedTensorType>();
auxTensorType = typeConverter->convertType(op.getResult1().getType())
.template cast<RankedTensorType>();
outputType = cast<RankedTensorType>(
typeConverter->convertType(op.getResult0().getType()));
auxTensorType = cast<RankedTensorType>(
typeConverter->convertType(op.getResult1().getType()));
Type auxTensorElementType = auxTensorType.getElementType();
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
@ -799,8 +797,8 @@ public:
Location loc = op->getLoc();
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
outputType = typeConverter->convertType(op.getResult().getType())
.template cast<RankedTensorType>();
outputType = cast<RankedTensorType>(
typeConverter->convertType(op.getResult().getType()));
buffVal = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 0));
auxTensor = rewriter.create<tensor::EmptyOp>(

View File

@ -42,9 +42,8 @@ public:
if (train)
return failure();
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getInput());
return success();
@ -60,8 +59,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc,
Value result =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type()));
for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) {
assert(index.getType().isa<mlir::IntegerType>() &&
stride.getType().isa<mlir::IntegerType>() &&
assert(isa<mlir::IntegerType>(index.getType()) &&
isa<mlir::IntegerType>(stride.getType()) &&
"Input arrays to `toLinearIndex` must only contain values of type "
"`mlir::IntegerType`");
Value mul = b.create<arith::MulIOp>(loc, result, stride);
@ -129,7 +128,7 @@ public:
if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type");
if (!generator.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
@ -180,7 +179,7 @@ public:
b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
Value truncRes = res;
if (elemTy.isa<Float16Type, Float32Type>())
if (isa<Float16Type, Float32Type>(elemTy))
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
b.create<linalg::YieldOp>(loc, truncRes);
})

View File

@ -86,11 +86,8 @@ public:
bool isUnsigned = false;
if (!isa<mlir::FloatType>(inElementType)) {
if (isa<mlir::IntegerType>(inElementType)) {
auto integerTy = op.getSelf()
.getType()
.template cast<BaseTensorType>()
.getDtype()
.template dyn_cast<mlir::IntegerType>();
auto integerTy = dyn_cast<mlir::IntegerType>(
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
isUnsigned = integerTy.isUnsigned();
} else {
return rewriter.notifyMatchFailure(
@ -280,7 +277,7 @@ public:
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
Type resultElementType) {
if (elem.getType().isa<mlir::ComplexType>()) {
if (isa<mlir::ComplexType>(elem.getType())) {
return b.create<complex::AbsOp>(loc, elem);
}
@ -376,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MaximumFOp>(loc, self, result);
else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = max.getSelf()
.getType()
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
IntegerType intType = dyn_cast<mlir::IntegerType>(
cast<BaseTensorType>(max.getSelf().getType()).getDtype());
if (intType.isUnsigned())
return b.create<arith::MaxUIOp>(loc, self, result);
if (intType.isSigned())
@ -393,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MinimumFOp>(loc, self, result);
else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = min.getSelf()
.getType()
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
IntegerType intType = dyn_cast<mlir::IntegerType>(
cast<BaseTensorType>(min.getSelf().getType()).getDtype());
if (intType.isUnsigned())
return b.create<arith::MinUIOp>(loc, self, result);
if (intType.isSigned())
@ -657,9 +648,8 @@ public:
return opInfo;
Location loc = op->getLoc();
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elemType = resultType.getElementType();
LogicalResult elemTypeCheck =
validateReductionElementType(op, elemType, rewriter);

View File

@ -179,15 +179,13 @@ public:
for (auto i : {TOP, VCENTER, BOTTOM}) {
for (auto j : {LEFT, HCENTER, RIGHT}) {
auto constVtile{
auto constVtile{dyn_cast_or_null<mlir::IntegerAttr>(
mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp())
.getValue()
.dyn_cast_or_null<mlir::IntegerAttr>()};
.getValue())};
auto constHtile{
auto constHtile{dyn_cast_or_null<mlir::IntegerAttr>(
mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp())
.getValue()
.dyn_cast_or_null<mlir::IntegerAttr>()};
.getValue())};
auto vSize = constVtile.getInt();
auto hSize = constHtile.getInt();
@ -369,8 +367,8 @@ public:
for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType;
if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = resultType.getElementType();
@ -426,7 +424,7 @@ public:
op, "unimplemented: pin_memory must be either None or false");
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)))
@ -441,7 +439,7 @@ public:
}
// TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure(
@ -453,7 +451,7 @@ public:
// TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0.
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure(
@ -478,7 +476,7 @@ public:
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = getDefaultDtypeForTorchScalar(
Torch::FloatType::get(op->getContext()));
} else {
@ -527,7 +525,7 @@ public:
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
@ -536,9 +534,8 @@ public:
Location loc = op.getLoc();
const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
Type dtype = resultType.getElementType();
Value start =
convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);

View File

@ -138,17 +138,16 @@ public:
requires_grad = tensorFloatOp.getRequiresGrad();
}
// TODO: Dtype conversion.
if (!dtype.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
// TODO: Device information.
if (!device.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(device.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None device information");
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type outElementType = resultType.getElementType();
Value elemValProm =
convertScalarToDtype(rewriter, loc, elemVal, outElementType);
@ -171,9 +170,8 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type outElementType = resultType.getElementType();
Value elemVal = adaptor.getA();
Value elemValProm =

View File

@ -422,7 +422,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat;
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
if (!isa<Torch::NoneType>(clone.getMemoryFormat().getType()) &&
(!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
@ -434,24 +434,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return payloadArgs[0];
}
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseAndTensor.getType()).getDtype())) {
bitwiseAndTensor.emitError(
"Bitwise_And does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseAndTensor.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseAndTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
Type dtype = converter->convertType(bitwiseAndScalar.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseAndScalar.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseAndScalar.emitError(
@ -469,32 +467,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::AndIOp>(loc, self, other);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (bitwiseOrTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseOrTensor.getType()).getDtype())) {
bitwiseOrTensor.emitError(
"Bitwise_Or does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseOrTensor.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseOrTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs);
}
if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) {
if (bitwiseXorTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseXorTensor.getType()).getDtype())) {
bitwiseXorTensor.emitError(
"Bitwise_Xor does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseXorTensor.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseXorTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
@ -502,8 +496,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto bitwiseRightShiftTensor =
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseRightShiftTensor.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseRightShiftTensor.emitError(
@ -516,8 +510,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto bitwiseLeftShiftTensor =
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
.cast<RankedTensorType>()
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseLeftShiftTensor.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseLeftShiftTensor.emitError(
@ -557,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op)) {
if (payloadArgs[0].getType().isa<IntegerType>())
if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
}
@ -653,20 +647,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
}
if (auto round = dyn_cast<AtenRoundOp>(op)) {
if (!round.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(round.getType()).getDtype())) {
round.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
}
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!prelu.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(prelu.getType()).getDtype())) {
prelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -685,10 +675,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
}
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
if (!gelu.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(gelu.getType()).getDtype())) {
gelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -732,10 +720,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
if (!geluBackward.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(geluBackward.getType()).getDtype())) {
geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -770,10 +756,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
if (!hardtanhBackward.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(hardtanhBackward.getType()).getDtype())) {
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -967,10 +951,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
if (!pow.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(pow.getType()).getDtype())) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -1047,10 +1029,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
if (!lerp.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(lerp.getType()).getDtype())) {
lerp.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
@ -1064,8 +1044,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
Type elemTy = converter->convertType(minimum.getType())
.cast<RankedTensorType>()
Type elemTy =
cast<RankedTensorType>(converter->convertType(minimum.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
@ -1074,8 +1054,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
Type elemTy = converter->convertType(maximum.getType())
.cast<RankedTensorType>()
Type elemTy =
cast<RankedTensorType>(converter->convertType(maximum.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
@ -1086,8 +1066,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
if (min.getType().isa<Torch::OptionalType>() ||
max.getType().isa<Torch::OptionalType>()) {
if (isa<Torch::OptionalType>(min.getType()) ||
isa<Torch::OptionalType>(max.getType())) {
clamp.emitError("unimplemented: runtime optional type");
return nullptr;
}
@ -1125,9 +1105,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
};
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(min.getType()))
result = cmpSelect(result, min, /*getMax=*/false);
if (!max.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(max.getType()))
result = cmpSelect(result, max, /*getMax=*/true);
return result;
}
@ -1135,8 +1115,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
AtenClampTensorOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
if (min.getType().isa<Torch::OptionalType>() ||
max.getType().isa<Torch::OptionalType>()) {
if (isa<Torch::OptionalType>(min.getType()) ||
isa<Torch::OptionalType>(max.getType())) {
clampTensor.emitError("unimplemented: runtime optional type");
return nullptr;
}
@ -1145,7 +1125,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
bool isMinNone = true;
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(min.getType())) {
isMinNone = false;
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value pred;
@ -1163,7 +1143,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
}
if (!max.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(max.getType())) {
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
Value pred;
@ -1252,8 +1232,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::DivFOp>(loc, self, other);
}
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
Type newResultType = converter->convertType(remScalar.getType())
.cast<RankedTensorType>()
Type newResultType =
cast<RankedTensorType>(converter->convertType(remScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@ -1272,8 +1252,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result;
}
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
Type newResultType = converter->convertType(remTensor.getType())
.cast<RankedTensorType>()
Type newResultType =
cast<RankedTensorType>(converter->convertType(remTensor.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@ -1292,8 +1272,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result;
}
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
Type newResultType = converter->convertType(fmod.getType())
.cast<RankedTensorType>()
Type newResultType =
cast<RankedTensorType>(converter->convertType(fmod.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@ -1420,8 +1400,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
Type elementType = converter->convertType(bitwiseNot.getType())
.cast<RankedTensorType>()
Type elementType =
cast<RankedTensorType>(converter->convertType(bitwiseNot.getType()))
.getElementType();
if (isa<mlir::FloatType>(elementType)) {
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
@ -1607,10 +1587,9 @@ public:
Location loc = op->getLoc();
auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),
@ -1657,7 +1636,7 @@ public:
return rewriter.notifyMatchFailure(op, "dim must be constant");
// TODO: Incorporate the weight argument.
if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weight operand is not incorporated.");
@ -1672,9 +1651,8 @@ public:
return rewriter.notifyMatchFailure(
op, "expected input and target to be rank <= 2");
}
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Value zeroVal = rewriter.create<arith::ConstantOp>(
@ -1948,7 +1926,7 @@ public:
Value input = adaptor.getSelf();
Value target = adaptor.getTarget();
Value weight = adaptor.getWeight();
bool weightIsNone = op.getWeight().getType().isa<Torch::NoneType>();
bool weightIsNone = isa<Torch::NoneType>(op.getWeight().getType());
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
Value totalWeight = adaptor.getTotalWeight();
@ -2069,9 +2047,8 @@ public:
})
->getResult(0);
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
return success();
}
@ -2214,9 +2191,8 @@ public:
LogicalResult
matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getOperand());
return success();
@ -2243,7 +2219,7 @@ public:
if (succeeded(checkNotNone(rewriter, op, eps)))
handleEps = true;
if (handleEps && !eps.getType().isa<mlir::FloatType>()) {
if (handleEps && !isa<mlir::FloatType>(eps.getType())) {
op.emitError("Logit does not support non-floating point type");
return failure();
}
@ -2317,9 +2293,8 @@ public:
LogicalResult
matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
@ -2362,8 +2337,8 @@ public:
zeropoint = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
auto resultType = converter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()));
llvm::SmallVector<Value> dynSizes;
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
@ -2553,9 +2528,8 @@ public:
return res;
};
auto resultType = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
SmallVector<Value> resultSize{};
if (resultType.isDynamicDim(0))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@ -2675,7 +2649,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> scaleValues,
std::string coordStr) {
auto inputType = input.getType().cast<RankedTensorType>();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
SmallVector<Value> indices;
@ -2725,7 +2699,7 @@ static Value BilinearInterpolate(OpBuilder &b,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = input.getType().cast<RankedTensorType>();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Value cstOneEps =
@ -2877,7 +2851,7 @@ public:
Location loc = op->getLoc();
Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
return rewriter.notifyMatchFailure(
@ -2893,7 +2867,7 @@ public:
loc, rewriter.getIntegerType(64), inputSize));
}
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
bool recompScale;
if (!matchPattern(op.getRecomputeScaleFactor(),
m_TorchConstantBool(&recompScale)))

View File

@ -52,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor(
Value torch_to_linalg::getZeroPaddedTensor(
Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &paddingInts) {
assert(input.getType().isa<RankedTensorType>() &&
assert(isa<RankedTensorType>(input.getType()) &&
"input must be RankedTensorType");
Location loc = op->getLoc();
Value c0 = b.create<arith::ConstantOp>(
@ -67,7 +67,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
Value torch_to_linalg::getDynamicZeroPaddedTensor(
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
int unpaddedDims, Value pad) {
assert(input.getType().isa<RankedTensorType>() &&
assert(isa<RankedTensorType>(input.getType()) &&
"input must be RankedTensorType");
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
Location loc = op->getLoc();

View File

@ -252,7 +252,7 @@ public:
// "block" arguments
for (const auto &barg : enumerate(op.getRegion().front().getArguments())) {
Value to = block->getArgument(barg.index());
if (to.getType().isa<mlir::IndexType>())
if (isa<mlir::IndexType>(to.getType()))
to =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
Type targetType = to.getType();

View File

@ -146,9 +146,9 @@ public:
if (!selfType) {
return op.emitError("only Tensor types supported in StableHLO");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success();
@ -203,9 +203,9 @@ public:
auto selfTy = cast<TensorType>(self.getType());
if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isa<mlir::FloatType>(resultTy.getElementType())) {
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
@ -231,9 +231,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!outType)
return op.emitError("only Tensor types supported in StableHLO");
@ -321,9 +321,9 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError("only Tensor types supported");
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
@ -354,9 +354,9 @@ public:
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
TensorType outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
@ -607,9 +607,9 @@ public:
if (!lhsTy)
return op.emitError("lhs must be a ranked tensor type");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
TensorType outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
if (!rhsTy) {
@ -917,9 +917,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
->convertType(op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
@ -1421,9 +1421,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>()
.getFloatSemantics()));
numFeatureDimSize,
APFloat::getZero(
cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics()));
SmallVector<APFloat> oneConstVec(
numFeatureDimSize,
APFloat(
@ -1633,9 +1633,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
Location loc = op->getLoc();
// Get element type of resultType as dtype
auto outType = this->getTypeConverter()
->convertType(op.getType())
.cast<RankedTensorType>();
auto outType = cast<RankedTensorType>(
this->getTypeConverter()->convertType(op.getType()));
auto dtype = outType.getElementType();
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
return rewriter.notifyMatchFailure(
@ -1678,7 +1677,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
AtenConstantPadNdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>();
auto selfTy = cast<RankedTensorType>(self.getType());
auto selfElemTy = selfTy.getElementType();
int64_t rank = selfTy.getRank();
@ -2029,7 +2028,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>();
auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy.hasStaticShape()) {
return op->emitError("dynamic shaped input is not supported");
}
@ -2062,7 +2061,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
cmpTypeAttr);
auto resTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto bcastTy = resTy.clone(rewriter.getI1Type());
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
@ -2071,15 +2070,15 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
auto resElemTy = resTy.getElementType();
Value zeroTensor;
if (resElemTy.isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(resElemTy)) {
auto constAttr = SplatElementsAttr::get(
resTy, llvm::APFloat::getZero(
resElemTy.cast<FloatType>().getFloatSemantics(), false));
cast<FloatType>(resElemTy).getFloatSemantics(), false));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else if (resElemTy.isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(resElemTy)) {
auto constAttr = SplatElementsAttr::get(
resTy,
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
llvm::APInt::getZero(cast<mlir::IntegerType>(resElemTy).getWidth()));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else {
return op.emitError("element type is not float or integer");

View File

@ -157,8 +157,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value builtinTypeStart = adaptor.getStart();
Value builtinTypeEnd = adaptor.getEnd();
if (torchTypeStart.getType().isa<OptionalType>() ||
torchTypeEnd.getType().isa<OptionalType>())
if (isa<OptionalType>(torchTypeStart.getType()) ||
isa<OptionalType>(torchTypeEnd.getType()))
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
int64_t step;
@ -349,11 +349,11 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "offsets must be a vector with static shape equal to 1");
if (!op.getPaddingIdx().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getPaddingIdx().getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented: padding_idx should be none");
if (!op.getPerSampleWeights().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getPerSampleWeights().getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented: per_sample_weights should be none");
@ -453,25 +453,22 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
loc, getTypeConverter()->convertType(op.getType(0)),
stablehloReduceOp.getResult(0), outShapeTensor);
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(1).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(1).getType()));
Value resultB =
createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultB)
return failure();
resultType = getTypeConverter()
->convertType(op->getResult(2).getType())
.cast<RankedTensorType>();
resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(2).getType()));
Value resultC =
createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultC)
return failure();
resultType = getTypeConverter()
->convertType(op->getResult(3).getType())
.cast<RankedTensorType>();
resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(3).getType()));
Value resultD =
createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultD)
@ -612,9 +609,8 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
auto input = adaptor.getSelf();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape;
SmallVector<Value> offsets;

View File

@ -350,9 +350,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>(),
cast<RankedTensorType>(
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(
op.getType())),
output);
return success();
@ -730,9 +730,8 @@ public:
// If transposed is set to true,
// the weight shape changes to [IC, (OC//G), KH, KW]
auto weightTy = cast<RankedTensorType>(weight.getType());
auto outTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
auto outTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!inputTy || !weightTy || !outTy) {
return op.emitError("input, weight and output must be ranked tensors");
}

View File

@ -216,10 +216,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
@ -395,9 +395,8 @@ public:
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
Type inputElemTy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
RankedTensorType outTy = cast<RankedTensorType>(
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
auto outShape = outTy.getShape();
if (inputRank <= Dim) {

View File

@ -242,10 +242,10 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
@ -535,12 +535,10 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
"AtenMaxDimOp to StableHLO");
}
RankedTensorType valResultType = getTypeConverter()
->convertType(op.getResult(0).getType())
.template cast<RankedTensorType>();
RankedTensorType idxResultType = getTypeConverter()
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
RankedTensorType valResultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult(0).getType()));
RankedTensorType idxResultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult(1).getType()));
Type idxElementType = idxResultType.getElementType();
if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result");
@ -636,9 +634,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
auto outTy =
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");

View File

@ -271,7 +271,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto getOptionalVal = [&](Value val) -> std::optional<Value> {
if (val.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(val.getType())) {
return std::nullopt;
} else {
return val;
@ -451,7 +451,7 @@ template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}

View File

@ -292,7 +292,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
arith::CmpIPredicate predicate = isDescending ? ge : le;
compareOp = rewriter.create<arith::CmpIOp>(
loc, predicate, block->getArgument(0), block->getArgument(1));
} else if (elementTypes[0].isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(elementTypes[0])) {
// Case for using arith::CmpFOp.
arith::CmpFPredicate predicate =
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
@ -349,8 +349,8 @@ public:
b.create<TMTensor::YieldOp>(loc, updatesElement);
});
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}
@ -381,7 +381,7 @@ public:
// Check whether the input is a 1-d tensor of integer type or not.
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
if (inputType.getRank() != 1 ||
!inputType.getElementType().isa<mlir::IntegerType>())
!isa<mlir::IntegerType>(inputType.getElementType()))
return rewriter.notifyMatchFailure(
op,
"Input tensor has to be a one-dimensional tensor of integer type.");
@ -395,7 +395,7 @@ public:
"Unimplemented: Integer width not equal to 64 are not supported.");
// TODO: Incorporate the weight argument.
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
if (!isa<mlir::torch::Torch::NoneType>(weights.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented: the weights operand is not incorporated.");
@ -439,8 +439,8 @@ public:
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
Type resultElemType = resultType.getElementType();
SmallVector<Value, 1> inputSizeDynamic =
@ -686,8 +686,8 @@ public:
auto valuesType = cast<ValueTensorType>(values.getType());
int64_t inputRank = inputType.getSizes().size();
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
if (!valuesTensorType.hasSizes())
return rewriter.notifyMatchFailure(
@ -823,10 +823,10 @@ public:
Value inputElement) {
Value yieldValue = valuesElement;
if (accumulate) {
if (inputElement.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {
@ -1042,10 +1042,10 @@ public:
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
Value yieldValue = valuesElement;
if (inputElement.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {
@ -1204,33 +1204,33 @@ public:
Value result;
if (reduceEnum == torch_upstream::ReductionType::SUM ||
reduceEnum == torch_upstream::ReductionType::MEAN) {
if (update.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::AddIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::AddFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
if (update.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MulIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MulFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
if (update.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MaxSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MaximumFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
if (update.getType().isa<mlir::IntegerType>()) {
if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MinSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
} else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MinimumFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
@ -1285,9 +1285,8 @@ public:
})
.getResult()[0];
}
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
@ -1392,9 +1391,8 @@ public:
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Type inputElementType =
cast<RankedTensorType>(input.getType()).getElementType();
@ -1414,7 +1412,7 @@ public:
int64_t inputRank = resultType.getRank();
Value dtype = op.getDtype();
if (!dtype.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(
op, "unsupported: dtype argument not supported");
@ -1444,7 +1442,7 @@ public:
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
[](OpBuilder &b, Location loc, Value input, Value acc) {
Value sum =
(input.getType().isa<mlir::FloatType>()
(isa<mlir::FloatType>(input.getType())
? b.create<arith::AddFOp>(loc, input, acc)->getResult(0)
: b.create<arith::AddIOp>(loc, input, acc)->getResult(0));
b.create<TMTensor::YieldOp>(loc, sum);
@ -1472,7 +1470,7 @@ public:
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
// Verify inputs (only support defaults)
if (!mask.getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;
@ -1483,7 +1481,7 @@ public:
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
if (!scale.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0)

View File

@ -1,5 +1,5 @@
//===----------------------------------------------------------------------===//
//
////
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@ -47,7 +47,7 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
if (selfTy.getElementType().isa<mlir::FloatType>()) {
if (isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@ -99,9 +99,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
auto binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
@ -248,9 +248,9 @@ public:
}
// Get output type: tensor<i32/i64/f32>
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
@ -373,9 +373,9 @@ public:
std::is_same<AtenOpT, AtenLtScalarOp>());
// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
TensorType resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
@ -416,9 +416,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
@ -444,9 +444,9 @@ public:
}
if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsTensor, /*shift=*/0);
@ -492,9 +492,9 @@ public:
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
// auto result;
Value result;
@ -540,7 +540,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
@ -557,7 +557,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
@ -584,7 +584,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
}
// Rescale the clampIn for quantized types. TBD
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}
@ -604,7 +604,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}
@ -667,9 +667,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
auto outputTy = cast<RankedTensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!outputTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted for reduce_mean");
@ -828,9 +828,8 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");
auto resultTy = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
auto resultTy = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
auto outputETy = resultTy.getElementType();
// Create a single instance of tosa.argmax.
@ -927,9 +926,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Squeeze could not compute new shape");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getResult().getType())
.template cast<RankedTensorType>();
auto resultTy = cast<RankedTensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getResult().getType()));
auto resultElemTy = resultTy.getElementType();
auto newOutputTy = RankedTensorType::get(
@ -1017,7 +1016,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
if (!selfTy.getElementType().isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
@ -1624,9 +1623,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>(),
cast<RankedTensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType())),
output);
return success();
@ -1800,9 +1799,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>(),
cast<RankedTensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType())),
matmulPlusBias);
return success();
@ -1823,7 +1822,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub");
if (!selfTy.getElementType().isa<mlir::FloatType>())
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
@ -1869,9 +1868,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = cast<RankedTensorType>(weight.getType());
auto outputTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
auto outputTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure(
@ -2208,7 +2206,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled.
// FIXME: Handle training and momentum.
if (op.getMomentum().getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(op.getMomentum().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
@ -2312,9 +2310,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled.
// FIXME: Handle the None cases for the optional parameters.
if (adaptor.getWeight().getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(adaptor.getWeight().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for weight");
if (adaptor.getBias().getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(adaptor.getBias().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
@ -2453,9 +2451,8 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto outputTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
auto outputTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
// Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse
@ -3122,7 +3119,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
if (!indicesType || !isa<IntegerType>(indicesType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Indices must be of integer tensor type");
@ -3632,11 +3629,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index other than i==0, like (index0, None)
// (None, index1)
if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) {
if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
// convert None to [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
if (indexNextTorch.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(indexNextTorch.getType())) {
return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now.");
}
@ -3963,8 +3960,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
if (!selfType.getElementType().isa<mlir::FloatType>() ||
!otherType.getElementType().isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(selfType.getElementType()) ||
!isa<mlir::FloatType>(otherType.getElementType())) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only FP element type is supported");
}
@ -4058,9 +4055,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.
@ -4068,7 +4064,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
// TODO: Add support for pin_memory features.
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
@ -4162,10 +4158,10 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
};
const auto isIntType =
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
dyn_cast_or_null<mlir::IntegerType>(resultType.getElementType());
const auto isDoubleType =
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();
dyn_cast_or_null<mlir::FloatType>(resultType.getElementType());
auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers.
@ -4218,9 +4214,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
@ -4323,7 +4318,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
}
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
@ -4336,9 +4331,8 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
"memory_format is supported");
}
auto resultTy = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
auto resultTy = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
Value result;
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),
@ -4779,9 +4773,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!outType)
return rewriter.notifyMatchFailure(op,
@ -4841,9 +4835,9 @@ public:
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(
@ -4875,9 +4869,9 @@ public:
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(
@ -4947,9 +4941,9 @@ public:
"unimplemented: only contiguous and channels last memory "
"format is supported");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
auto outType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.getSelf());
return success();
@ -5077,8 +5071,8 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto elementType = resultType.getElementType();
if (isa<mlir::IntegerType>(selfTy.getElementType())) {

View File

@ -813,9 +813,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype || output_is_qtype) {
op->emitOpError("ConvertReduceProdOp: input/output tensor should "
@ -839,9 +839,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
@ -894,9 +894,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
@ -905,7 +905,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
}
// Only supports float type mean() if it's non-quantized
if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element "
"not FloatType!");

View File

@ -31,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
return false;
auto tensor = dyn_cast<ValueTensorType>(type);
return !tensor ||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
dyn_cast_or_null<RankedTensorType>(tensor.toBuiltinTensor());
};
bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
@ -66,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
// Generate IR: assert(dim >= 0 && dim < inputRank)
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
assert(dim.getType().isa<IntegerType>() &&
assert(isa<IntegerType>(dim.getType()) &&
"dim arg of assertIsValidDim must be integer type");
Value cst0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
@ -139,12 +139,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
}
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(v.getType().isa<IntegerType>() && "must be called with integer type");
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
}
Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
assert(idx.getType().isa<IndexType>() && "must be called with integer type");
assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}
@ -375,7 +375,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize) {
if (torchOptionalInt.getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(torchOptionalInt.getType()))
return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =

View File

@ -149,14 +149,12 @@ static Value getScalarIntValue(Value input, Location loc,
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
if (inputDtype.isInteger(64)) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<bool>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
@ -191,8 +189,7 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseFPElementsAttr>()
auto val = cast<DenseFPElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<FloatAttr>()
.getValueAsDouble();
return rewriter.create<Torch::ConstantFloatOp>(
@ -1946,7 +1943,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) {
isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf();
}
return {};
@ -2136,7 +2133,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
// Limit the loop count to 6 to avoid indefinite compilation times from
// unbounded IR traversals.
for (auto idx = 0; idx < 6; ++idx) {
if (!value || !value.getType().isa<BaseTensorType>())
if (!value || !isa<BaseTensorType>(value.getType()))
return failure();
auto tensorType = cast<BaseTensorType>(value.getType());
@ -2518,7 +2515,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold int -> float conversion.
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
if (auto integerAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getA())) {
return FloatAttr::get(
mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue()));
@ -2535,7 +2532,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64),
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
@ -2549,7 +2546,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64),
static_cast<long>(floatAttr.getValue().convertToDouble()));
@ -2695,9 +2692,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = properties.as<Properties *>()
->getValue()
.dyn_cast_or_null<ElementsAttr>();
auto attr =
dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
if (!attr)
return failure();
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@ -2723,10 +2719,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
TypeRange actual) {
if (!actual[0].isa<BaseTensorType>())
if (!isa<BaseTensorType>(actual[0]))
return false;
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
actual[0].cast<BaseTensorType>());
return areSizesAndDtypesCompatible(cast<BaseTensorType>(inferred[0]),
cast<BaseTensorType>(actual[0]));
}
//===----------------------------------------------------------------------===//
@ -2737,9 +2733,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = properties.as<Properties *>()
->getValue()
.dyn_cast_or_null<ElementsAttr>();
auto attr =
dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
if (!attr)
return failure();
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@ -2760,8 +2755,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
mlir::TypeRange outputs) {
return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
outputs[0].cast<BaseTensorType>());
return areSizesAndDtypesCompatible(cast<BaseTensorType>(inputs[0]),
cast<BaseTensorType>(outputs[0]));
}
void TensorStaticInfoCastOp::getCanonicalizationPatterns(
@ -3072,7 +3067,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
if (!operandType)
return nullptr;
if (operandType.hasDtype()) {
bool isFloatType = operandType.getDtype().isa<mlir::FloatType>();
bool isFloatType = isa<mlir::FloatType>(operandType.getDtype());
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
}
// doesn't has dtype
@ -3130,12 +3125,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
int64_t start;
int64_t end;
int64_t step;
if (op.getStart().getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(op.getStart().getType())) {
start = 0;
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
return failure();
}
if (op.getEnd().getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(op.getEnd().getType())) {
end = listElements.size();
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
return failure();
@ -3228,7 +3223,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// things.
Value replacement = tupleConstruct.getElements()[i];
if (replacement.getType() != op.getType()) {
if (op.getType().isa<BaseTensorType>()) {
if (isa<BaseTensorType>(op.getType())) {
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
op.getLoc(), op.getType(), replacement);
} else {
@ -3384,8 +3379,8 @@ using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
static OpFoldResult
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryIntOperatorFn f) {
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
auto intLhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
auto intRhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
if (!intLhs || !intRhs) {
return nullptr;
}
@ -3711,7 +3706,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a + b; });
@ -3730,7 +3725,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a * b; });
@ -3749,7 +3744,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a - b; });
@ -3806,7 +3801,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
auto floatValue = dyn_cast_or_null<FloatAttr>(adaptor.getA());
if (!floatValue) {
return nullptr;
}
@ -3834,7 +3829,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
auto value = dyn_cast_or_null<FloatAttr>(adaptor.getA());
if (!value) {
return nullptr;
}
@ -4487,8 +4482,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
if (getA() == getB())
return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!lhs || !rhs)
return nullptr;
// Torch semantics are that !torch.int is 64-bit signed.
@ -4556,8 +4551,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
if (getA() == getB())
return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!lhs || !rhs)
return nullptr;
// Torch semantics are that !torch.int is 64-bit signed.
@ -4644,8 +4639,8 @@ LogicalResult AtenNormScalarOp::verify() {
// Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
mlir::Float32Type, mlir::Float64Type>()) {
if (!isa<mlir::ComplexType, mlir::Float16Type, mlir::Float32Type,
mlir::Float64Type>(inTensorDtype)) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< inTensorDtype;

View File

@ -190,8 +190,8 @@ static bool isValidTorchDtype(Type dtype) {
// Builtin floating point types.
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
return true;
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(dtype))
return true;
if (isa<Torch::StringType>(dtype))
@ -228,9 +228,9 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
Type BaseTensorType::getWithSizesAndDtype(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
if (isa<NonValueTensorType>())
if (mlir::isa<NonValueTensorType>(*this))
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
if (isa<ValueTensorType>())
if (mlir::isa<ValueTensorType>(*this))
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
llvm_unreachable("not a BaseTensorType!");
}
@ -248,9 +248,9 @@ Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
}
ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getWithValueSemantics();
if (auto tensor = dyn_cast<ValueTensorType>())
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor;
llvm_unreachable("not a BaseTensorType!");
}

View File

@ -110,7 +110,7 @@ public:
continue;
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
if (it != typeBoundMap.end()) {
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
if (auto valueTensorType = dyn_cast<ValueTensorType>(it->second)) {
newOperands.push_back(copyTensorToType(
rewriter, call->getLoc(), valueTensorType, operand.value()));
continue;
@ -215,11 +215,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
if (func.getArgAttr(i, "torch.type_bound"))
return false;
if (func.getArgumentTypes()[i].isa<Torch::NoneType>())
if (isa<Torch::NoneType>(func.getArgumentTypes()[i]))
return false;
}
for (int i = 0, e = func.getNumResults(); i != e; i++) {
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
if (isa<Torch::NoneType>(func.getFunctionType().getResults()[i]))
return false;
}
return true;

View File

@ -38,7 +38,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
if (failed(resDtype))
return false;
return resDtype->isa<mlir::FloatType>();
return isa<mlir::FloatType>(*resDtype);
}
// Helper function to compute the return type of the reduction function.
@ -99,19 +99,15 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType =
computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
dim, keepDim)
.cast<BaseTensorType>();
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
if (!valueType)
return nullptr;
BaseTensorType indexType =
valueType
.getWithSizesAndDtype(
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
return rewriter
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.getValues();
@ -1059,7 +1055,7 @@ public:
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto outType = op.getType().dyn_cast<BaseTensorType>();
auto outType = dyn_cast<BaseTensorType>(op.getType());
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
@ -1659,11 +1655,9 @@ public:
unsigned inputRank = *maybeInputRank;
if (!indicesTensorType.hasSizes())
return failure();
BaseTensorType valueTensorType =
inputType
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
inputType.getOptionalDtype())
.cast<BaseTensorType>();
BaseTensorType valueTensorType = cast<BaseTensorType>(
inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
inputType.getOptionalDtype()));
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
@ -1671,10 +1665,8 @@ public:
// happens on the 0th dimension.
if (isa<Torch::NoneType>(dim.getType())) {
BaseTensorType flattenType =
inputType
.getWithSizesAndDtype({kUnknownSize},
inputType.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
{kUnknownSize}, inputType.getOptionalDtype()));
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));
@ -3003,7 +2995,7 @@ public:
bool dimIsNone = false;
int64_t dim;
Value dimValue = op.getDim();
if (dimValue.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(dimValue.getType())) {
dimIsNone = true;
dim = inputRank - 1;
} else {
@ -3887,10 +3879,9 @@ public:
gradOutputViewSizesInt[0] = kUnknownSize;
gradOutputViewSizesInt[1] = 1;
BaseTensorType gradOutputTypeForView =
gradOutputTy
.getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt),
gradOutputTy.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(gradOutputTy.getWithSizesAndDtype(
llvm::ArrayRef(gradOutputViewSizesInt),
gradOutputTy.getOptionalDtype()));
Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
@ -3918,10 +3909,9 @@ public:
}
BaseTensorType gradWeightTy =
inputTransposedTy
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
inputTransposedTy.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(inputTransposedTy.getWithSizesAndDtype(
llvm::ArrayRef(gradWeightSizesInt),
inputTransposedTy.getOptionalDtype()));
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
@ -3937,10 +3927,9 @@ public:
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
gradWeightSizesInt[i + 2] = weightSizes[i + 2];
BaseTensorType gradWeightNarrowTy =
gradWeightTy
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
gradWeightTy.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
llvm::ArrayRef(gradWeightSizesInt),
gradWeightTy.getOptionalDtype()));
Value dim = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i + 2));
@ -3970,10 +3959,9 @@ public:
gradWeightViewShapeValue);
BaseTensorType gradWeightTypeForView =
gradWeightTy
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt),
gradWeightTy.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
llvm::ArrayRef(gradWeightViewShapeInt),
gradWeightTy.getOptionalDtype()));
gradWeight = rewriter.create<Torch::AtenViewOp>(
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
@ -3986,10 +3974,9 @@ public:
gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
}
BaseTensorType gradWeightTypeForMoveDim =
gradWeightTy
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape),
gradWeightTy.getOptionalDtype())
.cast<BaseTensorType>();
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
llvm::ArrayRef(gradWeightMoveDimShape),
gradWeightTy.getOptionalDtype()));
gradWeight = rewriter.create<AtenMovedimIntOp>(
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
@ -4009,8 +3996,7 @@ public:
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, gradOutput, cstZero, cstOne);
// Convolve input with grad_output.
if (failed(
getTransposedType(op.getResultTypes()[1].cast<BaseTensorType>(),
if (failed(getTransposedType(cast<BaseTensorType>(op.getResultTypes()[1]),
0, 1, transposedType)))
return failure();
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
@ -4063,7 +4049,7 @@ public:
// TODO: Handle integer type operands.
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point dtype");
}
@ -4125,7 +4111,7 @@ public:
MLIRContext *context = op.getContext();
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()) ||
!isNoneOrFloatDtype(context, dtype)) {
return rewriter.notifyMatchFailure(
op, "only floating-point type is supported");
@ -4133,7 +4119,7 @@ public:
SmallVector<Value> dimListElements;
if (!getListConstructElements(dimList, dimListElements) &&
!dimList.getType().isa<Torch::NoneType>()) {
!isa<Torch::NoneType>(dimList.getType())) {
return rewriter.notifyMatchFailure(
op, "expected `dim` to be `None` or constructed from list construct");
}
@ -4215,7 +4201,7 @@ public:
return success();
}
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
return rewriter.notifyMatchFailure(
op, "only support floating type input for training mode");
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
@ -4243,7 +4229,7 @@ public:
Value input = op.getInput();
Value prob = op.getP();
bool train = false;
if (!op.getTrain().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getTrain().getType())) {
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
return rewriter.notifyMatchFailure(
op, "train must be a boolean constant or none");
@ -4263,7 +4249,7 @@ public:
return success();
}
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure(
op, "only support floating type input for training mode");
}
@ -4332,7 +4318,7 @@ public:
Value self = op.getSelf();
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
return rewriter.notifyMatchFailure(op,
"Only aten.std support floating type");
}
@ -4388,7 +4374,7 @@ public:
Value self = op.getSelf();
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
if (!inputTensorType.hasDtype() ||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
return rewriter.notifyMatchFailure(
op, "aten.std.dim expects input tensor of floating-point type");
}
@ -4413,7 +4399,7 @@ public:
Value self = op.getSelf();
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
if (!inputTensorType.hasDtype() ||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
return rewriter.notifyMatchFailure(
op,
"aten.std.correction expects input tensor of floating-point type");
@ -4506,7 +4492,7 @@ public:
Value input = op.getSelf();
Type resultType = op.getType();
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure(op,
"only support floating-point type");
}
@ -4547,7 +4533,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
op, "can't decompose bernoulli like ops without sizes or dtype");
}
// The `prob` is expected to be a float type tensor.
if (!probType.getDtype().isa<mlir::FloatType>()) {
if (!isa<mlir::FloatType>(probType.getDtype())) {
return rewriter.notifyMatchFailure(
op, "probabilities must be a float type tensor");
}
@ -4582,7 +4568,7 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
if (!op.getGenerator().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
@ -4640,7 +4626,7 @@ public:
Location loc = op.getLoc();
Value input = op.getSelf();
Value prob = op.getP();
if (!op.getGenerator().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
@ -4665,7 +4651,7 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExponentialOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
@ -4706,7 +4692,7 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
@ -4984,10 +4970,10 @@ class DecomposeAtenNativeLayerNormOp
Value weight = op.getWeight();
Value bias = op.getBias();
if (!weight.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(weight.getType())) {
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(bias.getType())) {
out =
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
}
@ -5238,13 +5224,13 @@ class DecomposeAtenNativeGroupNormOp
loc, ListType::get(IntType::get(context)), viewShape);
Value groupNormOutput = reshapedOutput;
if (!weight.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(weight.getType())) {
auto weightReshaped = rewriter.create<AtenViewOp>(
loc, baseType, weight, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenMulTensorOp>(
loc, inputType, groupNormOutput, weightReshaped);
}
if (!bias.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(bias.getType())) {
auto biasReshaped = rewriter.create<AtenViewOp>(
loc, baseType, bias, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenAddTensorOp>(
@ -5297,8 +5283,8 @@ class DecomposeAtenNativeBatchNormOp
// In the inference mode, the `runningMean` and `runningVar` must not be
// None.
if (runningMean.getType().isa<Torch::NoneType>() ||
runningVar.getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(runningMean.getType()) ||
isa<Torch::NoneType>(runningVar.getType()))
return rewriter.notifyMatchFailure(
op, "running stats must not be None in inference mode");
@ -5354,7 +5340,7 @@ class DecomposeAtenNativeBatchNormOp
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
// 3. output = normalizedInput * weight + bias
Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(weight.getType())) {
// Rank of `weight` must be exactly 1.
std::optional<unsigned> weightRank = getTensorRank(weight);
if (!weightRank || *weightRank != 1)
@ -5364,7 +5350,7 @@ class DecomposeAtenNativeBatchNormOp
batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(bias.getType())) {
// Rank of `bias` must be exactly 1.
std::optional<unsigned> biasRank = getTensorRank(bias);
if (!biasRank || *biasRank != 1)
@ -5444,7 +5430,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
@ -5518,7 +5504,7 @@ public:
return transposeWeight;
};
if (bias.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(bias.getType())) {
auto weightRank = weightType.getSizes().size();
if (weightRank > 2 || weightRank <= 0)
return rewriter.notifyMatchFailure(
@ -5622,7 +5608,7 @@ public:
LogicalResult matchAndRewrite(AtenNewFullOp op,
PatternRewriter &rewriter) const override {
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
@ -5718,7 +5704,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
PatternRewriter &rewriter) const override {
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
@ -5743,9 +5729,9 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
PatternRewriter &rewriter) const override {
Value value = op.getValue();
if (value.getType().isa<Torch::OptionalType>())
if (isa<Torch::OptionalType>(value.getType()))
return rewriter.notifyMatchFailure(op, "optional type not supported");
if (value.getType().isa<Torch::NoneType>())
if (isa<Torch::NoneType>(value.getType()))
value = rewriter.create<Torch::ConstantFloatOp>(
op.getLoc(), rewriter.getF64FloatAttr(0));
@ -5765,7 +5751,7 @@ public:
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
PatternRewriter &rewriter) const override {
// TODO: Add support for pinMemory arg equal to `True`.
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
bool pinMemory;
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
return rewriter.notifyMatchFailure(
@ -5776,7 +5762,7 @@ public:
}
// TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure(
@ -5788,7 +5774,7 @@ public:
// TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0.
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure(
@ -6254,7 +6240,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Type newOutputType = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
return rewriter.notifyMatchFailure(
op, "support floating-point type input only");
}
@ -6391,14 +6377,14 @@ public:
PatternRewriter &rewriter) const override {
int64_t correctionValInt;
double correctionValFloat = 1.0;
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
if (!isa<Torch::NoneType>(op.getCorrection().getType())) {
if (isa<Torch::FloatType>(op.getCorrection().getType())) {
if (!matchPattern(op.getCorrection(),
m_TorchConstantFloat(&correctionValFloat)))
return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for "
"aten.var");
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
} else if (isa<Torch::IntType>(op.getCorrection().getType())) {
if (!matchPattern(op.getCorrection(),
m_TorchConstantInt(&correctionValInt)))
return rewriter.notifyMatchFailure(
@ -6525,11 +6511,9 @@ public:
if (!inputType.hasSizes())
return rewriter.notifyMatchFailure(
op, "Expected the input tensor to have sizes");
BaseTensorType subType =
inputType
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
resultType.getOptionalDtype())
.cast<BaseTensorType>();
BaseTensorType subType = cast<BaseTensorType>(
inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
resultType.getOptionalDtype()));
Value sub =
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@ -6566,7 +6550,7 @@ public:
Location loc = op->getLoc();
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value ord = op.getP();
if (ord.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(ord.getType())) {
ord = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));
}
@ -6609,10 +6593,8 @@ public:
loc, rewriter.getF64FloatAttr((double)cstHigh));
BaseTensorType floatResultType =
resultTensorType
.getWithSizesAndDtype(resultTensorType.getSizes(),
rewriter.getF32Type())
.cast<BaseTensorType>();
cast<BaseTensorType>(resultTensorType.getWithSizesAndDtype(
resultTensorType.getSizes(), rewriter.getF32Type()));
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, floatResultType, op.getSize(), /*dtype=*/none,
/*layout=*/op.getLayout(),
@ -6704,7 +6686,7 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimsVarOp op,
PatternRewriter &rewriter) const override {
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
if (!isa<Torch::NoneType>(op.getOutputDtype().getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for prims::var op");
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
@ -6816,7 +6798,7 @@ public:
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
PatternRewriter &rewriter) const override {
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)))
@ -6913,8 +6895,8 @@ public:
op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1)
Value sub;
if (op.getEnd().getType().isa<Torch::FloatType>() ||
op.getStart().getType().isa<Torch::FloatType>()) {
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
isa<Torch::FloatType>(op.getStart().getType())) {
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
op.getEnd(), op.getStart());
} else {
@ -6930,7 +6912,7 @@ public:
}
// to dtype
Value result;
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
@ -7344,11 +7326,8 @@ public:
auto selfType = cast<BaseTensorType>(self.getType());
auto indexType = cast<BaseTensorType>(index.getType());
BaseTensorType srcType =
selfType
.getWithSizesAndDtype(indexType.getOptionalSizes(),
selfType.getOptionalDtype())
.cast<BaseTensorType>();
BaseTensorType srcType = cast<BaseTensorType>(selfType.getWithSizesAndDtype(
indexType.getOptionalSizes(), selfType.getOptionalDtype()));
Value src =
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
@ -7372,7 +7351,7 @@ public:
"expected result type to have dtype");
}
// TODO: support complex type in future.
if (outType.getDtype().isa<mlir::ComplexType>()) {
if (isa<mlir::ComplexType>(outType.getDtype())) {
return rewriter.notifyMatchFailure(op,
"doesn't support complex type now");
}
@ -7488,7 +7467,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
auto inputType = input.getType().cast<BaseTensorType>();
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes()) {
return failure();
}
@ -7497,7 +7476,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
int64_t maxIndexRank = 0;
for (auto index : oldIndices) {
auto indexType = index.getType().dyn_cast<BaseTensorType>();
auto indexType = dyn_cast<BaseTensorType>(index.getType());
if (!indexType) // None index
continue;
if (!indexType.hasSizes())
@ -7586,15 +7565,13 @@ public:
int64_t inputRank = inputSizes.size();
auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>();
return isa<Torch::BaseTensorType>(v.getType());
};
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
if (llvm::all_of(indices, isTensor)) {
// By default, we regard the first index type as the list element type.
auto indexElemType = indices[0]
.getType()
.template cast<BaseTensorType>()
auto indexElemType = cast<BaseTensorType>(indices[0].getType())
.getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndices = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices);
@ -7684,7 +7661,7 @@ public:
"failed to get elements of `indices`");
auto input = op.getSelf();
auto inputType = input.getType().template cast<BaseTensorType>();
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only input with shape information is supported");
@ -7693,15 +7670,13 @@ public:
int64_t inputRank = inputSizes.size();
auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>();
return isa<Torch::BaseTensorType>(v.getType());
};
// directly replace current op with aten.index_put.hacked_twin
if (llvm::all_of(indices, isTensor)) {
// By default, we regard the first index type as the list element type.
auto indexElemType = indices[0]
.getType()
.template cast<BaseTensorType>()
auto indexElemType = cast<BaseTensorType>(indices[0].getType())
.getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndex = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices);
@ -7831,7 +7806,7 @@ public:
// default ord value is 2 for vector_norm
auto ord = op.getOrd();
if (ord.getType().isa<Torch::NoneType>()) {
if (isa<Torch::NoneType>(ord.getType())) {
ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
}
rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(

View File

@ -63,8 +63,8 @@ public:
};
static bool isTypeTriviallySafe(Type type) {
return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>();
return isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>(type);
}
static bool isUseTreatedWithValueSemantics(OpOperand &use) {

View File

@ -36,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) {
// Allow various scalar types that backends are expected to be able to handle.
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::DeviceType>())
if (isa<Torch::IntType, Torch::FloatType, Torch::BoolType, Torch::DeviceType>(
type))
return success();
// Backends are not expected to support dynamic computations on these types,

View File

@ -187,7 +187,7 @@ public:
auto it = originalReturnTypes.find(i);
if (it == originalReturnTypes.end())
continue;
auto originalType = it->second.cast<NonValueTensorType>();
auto originalType = cast<NonValueTensorType>(it->second);
rewriter.setInsertionPoint(returnOp);
Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(),
originalType, operand.get());
@ -350,7 +350,7 @@ public:
auto it = originalTypes.find(operand.get());
if (it == originalTypes.end())
continue;
auto originalType = it->second.cast<BaseTensorType>();
auto originalType = cast<BaseTensorType>(it->second);
rewriter.setInsertionPoint(op);
Value newReturnValue = copyTensorToType(rewriter, op->getLoc(),
originalType, operand.get());

View File

@ -118,7 +118,7 @@ public:
if (auto optionalType =
dyn_cast<OptionalType>(listType.getContainedType())) {
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
return isa<NonValueTensorType, Torch::NoneType>(val.getType());
})) {
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(

View File

@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
if (name.starts_with("valsem."))
name = name.drop_front(strlen("valsem."));
if (isa<OperatorOp>(op))
name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
name = cast<StringAttr>(cast<OperatorOp>(op)->getAttr("name")).getValue();
std::string libFuncName =
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// to match the library function signature.
if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
return isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(
containedType);
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

View File

@ -179,11 +179,10 @@ public:
"should have concrete Scalar Type.");
}
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType =
auto impliedTypeFromInputType = cast<BaseTensorType>(
cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType)
.cast<BaseTensorType>();
inputType));
op.getResult().setType(impliedTypeFromInputType);
return success();

View File

@ -97,11 +97,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
}
auto originalResultType = cast<BaseTensorType>(result.getType());
auto impliedTypesFromShape =
auto impliedTypesFromShape = cast<BaseTensorType>(
cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype())
.cast<BaseTensorType>();
originalResultType.getOptionalDtype()));
return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
rewriter);

View File

@ -74,7 +74,7 @@ LogicalResult FromBuiltinTensorOp::verify() {
//===----------------------------------------------------------------------===//
OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {
@ -87,7 +87,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {
@ -100,7 +100,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {
@ -113,7 +113,7 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {
@ -126,7 +126,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {
@ -139,7 +139,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
if (attr) {
return attr;
} else {

View File

@ -91,7 +91,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
return std::nullopt;
// Other input type to be converted to i64 are handled by other
// materializers.
if (!inputs[0].getType().isa<Torch::IntType>())
if (!isa<Torch::IntType>(inputs[0].getType()))
return std::nullopt;
assert(inputs.size() == 1);
return builder.create<ToI64Op>(loc, inputs[0]).getResult();
@ -145,7 +145,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
return std::nullopt;
// Other input type to be converted to i64 are handled by other
// materializers.
if (!inputs[0].getType().isa<Torch::GeneratorType>())
if (!isa<Torch::GeneratorType>(inputs[0].getType()))
return std::nullopt;
assert(inputs.size() == 1);
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();

View File

@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
static bool isArgMemRefTypeValid(Type type) {
if (auto memRefType = dyn_cast<MemRefType>(type)) {
Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) {
return true;
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
if (integerTy.isSignlessInteger(64))
@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) {
if (integerTy.isSignlessInteger(1))
return true;
} else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
return complexTy.getElementType().isa<Float32Type, Float64Type>();
return isa<Float32Type, Float64Type>(complexTy.getElementType());
}
}
return false;