[NFC] Change to *cast instead of .*cast variants (#3405)

Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
pull/3407/head
Rob Suderman 2024-05-30 23:45:13 -07:00 committed by GitHub
parent 4e05e2cd1e
commit afca88a058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 551 additions and 658 deletions

View File

@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF) option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
# TODO(#3299): migrate to from member x.cast<T>() to mlir::cast<T>(x).
if(MSVC)
add_compile_options(/wd4996)
else()
add_compile_options(-Wno-deprecated-declarations)
endif()
macro(torch_mlir_enable_werror) macro(torch_mlir_enable_werror)
if(TORCH_MLIR_ENABLE_WERROR_FLAG) if(TORCH_MLIR_ENABLE_WERROR_FLAG)
if(NOT MSVC) if(NOT MSVC)

View File

@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(), llvm::copy_if(getInputOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperand) { [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>(); return isa<MemRefType>(opOperand->get().getType());
}); });
return result; return result;
}] }]
@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(), llvm::copy_if(getInputOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperand) { [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>(); return isa<RankedTensorType>(opOperand->get().getType());
}); });
return result; return result;
}] }]
@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(), llvm::copy_if(getOutputOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperand) { [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>(); return isa<MemRefType>(opOperand->get().getType());
}); });
return result; return result;
}] }]
@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(), llvm::copy_if(getOutputOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperand) { [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>(); return isa<RankedTensorType>(opOperand->get().getType());
}); });
return result; return result;
}] }]
@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputBufferOperands(), llvm::transform(getOutputBufferOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperands) { [](OpOperand *opOperands) {
return opOperands->get().getType().cast<MemRefType>(); return cast<MemRefType>(opOperands->get().getType());
}); });
return result; return result;
}] }]
@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputTensorOperands(), llvm::transform(getOutputTensorOperands(),
std::back_inserter(result), std::back_inserter(result),
[](OpOperand *opOperands) { [](OpOperand *opOperands) {
return opOperands->get().getType().cast<RankedTensorType>(); return cast<RankedTensorType>(opOperands->get().getType());
}); });
return result; return result;
}] }]
@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand), /*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"", /*methodBody=*/"",
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>()) if (!isa<RankedTensorType>(opOperand->get().getType()))
return false; return false;
if (opOperand->getOperandNumber() < $_op.getNumInputs()) if (opOperand->getOperandNumber() < $_op.getNumInputs())
return true; return true;
@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand), /*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"", /*methodBody=*/"",
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>()) if (!isa<RankedTensorType>(opOperand->get().getType()))
return false; return false;
if (opOperand->getOperandNumber() >= $_op.getNumInputs()) if (opOperand->getOperandNumber() >= $_op.getNumInputs())
return true; return true;
@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation()); assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType = if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>()) dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getRank(); return shapedType.getRank();
return 0; return 0;
}] }]
@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation()); assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType = if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>()) dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getShape(); return shapedType.getShape();
return {}; return {};
}] }]
@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*methodBody=*/"", /*methodBody=*/"",
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation()); assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>(); return !isa<ShapedType>(opOperand->get().getType());
}] }]
>, >,
//===------------------------------------------------------------------===// //===------------------------------------------------------------------===//
@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return this->getOperation()->getNumResults() == 0 && return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) || return isScalar(opOperand) ||
opOperand->get().getType().template isa<MemRefType>(); isa<MemRefType>(opOperand->get().getType());
}) && }) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>(); return isa<MemRefType>(opOperand->get().getType());
}); });
}] }]
>, >,
@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return return
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) || return isScalar(opOperand) ||
opOperand->get().getType().template isa<RankedTensorType>(); isa<RankedTensorType>(opOperand->get().getType());
}) && }) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>(); return isa<RankedTensorType>(opOperand->get().getType());
}); });
}] }]
>, >,
@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
private: private:
void setOperandSegmentAt(unsigned idx, unsigned val) { void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = (*this)->getAttr("operand_segment_sizes") auto attr = cast<DenseIntElementsAttr>((*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>(); );
unsigned i = 0; unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });

View File

@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
return getOutputOperand(0)->get(); return getOutputOperand(0)->get();
} }
ShapedType getOperandType() { ShapedType getOperandType() {
return input().getType().cast<ShapedType>(); return cast<ShapedType>(input().getType());
} }
int64_t getOperandRank() { int64_t getOperandRank() {
return getOperandType().getRank(); return getOperandType().getRank();
@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
int64_t getIndexDepth() { int64_t getIndexDepth() {
return getInputOperand(1) return cast<ShapedType>(getInputOperand(1)
->get() ->get()
.getType() .getType()
.cast<ShapedType>() )
.getShape() .getShape()
.back(); .back();
} }
@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
} }
ShapedType getUpdateType() { ShapedType getUpdateType() {
return updates().getType().cast<ShapedType>(); return cast<ShapedType>(updates().getType());
} }
Value indices() { Value indices() {
@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
} }
ShapedType getIndicesType() { ShapedType getIndicesType() {
return indices().getType().cast<ShapedType>(); return cast<ShapedType>(indices().getType());
} }
Value original() { Value original() {
@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
} }
ShapedType getOriginalType() { ShapedType getOriginalType() {
return original().getType().cast<ShapedType>(); return cast<ShapedType>(original().getType());
} }
int64_t getUpdateSliceRank() { int64_t getUpdateSliceRank() {
return updates().getType().cast<ShapedType>().getRank() - 1; return cast<ShapedType>(updates().getType()).getRank() - 1;
} }
bool isScalarUpdate() { bool isScalarUpdate() {
@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
return getOutputs()[index]; return getOutputs()[index];
} }
ShapedType getOperandType(int index) { ShapedType getOperandType(int index) {
return operand(index).getType().cast<ShapedType>(); return cast<ShapedType>(operand(index).getType());
} }
int64_t getOperandRank() { int64_t getOperandRank() {
return getOperandType(0).getRank(); return getOperandType(0).getRank();
@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
return getOutputOperand(0)->get(); return getOutputOperand(0)->get();
} }
ShapedType getQueryType() { ShapedType getQueryType() {
return getQuery().getType().cast<ShapedType>(); return cast<ShapedType>(getQuery().getType());
} }
ShapedType getKeyType() { ShapedType getKeyType() {
return getKey().getType().cast<ShapedType>(); return cast<ShapedType>(getKey().getType());
} }
ShapedType getValueType() { ShapedType getValueType() {
return getValue().getType().cast<ShapedType>(); return cast<ShapedType>(getValue().getType());
} }
ShapedType getOutputType() { ShapedType getOutputType() {
return getOutput().getType().cast<ShapedType>(); return cast<ShapedType>(getOutput().getType());
} }
int64_t getQueryRank() { int64_t getQueryRank() {
return getQueryType().getRank(); return getQueryType().getRank();

View File

@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder {
bool match(Operation *op) { bool match(Operation *op) {
auto constOp = dyn_cast<Torch::OperatorOp>(op); auto constOp = dyn_cast<Torch::OperatorOp>(op);
if (!constOp || !constOp.getName().equals("onnx.Constant")) if (!constOp || !(constOp.getName() == "onnx.Constant"))
return false; return false;
if (DenseResourceElementsAttr attr = if (DenseResourceElementsAttr attr =
constOp->getAttr("torch.onnx.value") dyn_cast_or_null<DenseResourceElementsAttr>(
.dyn_cast_or_null<DenseResourceElementsAttr>()) { constOp->getAttr("torch.onnx.value"))) {
// Bytes are stored in little endian order. Big endian support will // Bytes are stored in little endian order. Big endian support will
// require swizzling. // require swizzling.
if (!Endian::little) { if (!Endian::little) {

View File

@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
int64_t num; int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num))) if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num); bind_values.push_back(num);
else if (value.getType().isa<Torch::NoneType>()) else if (isa<Torch::NoneType>(value.getType()))
bind_values.push_back(std::nullopt); bind_values.push_back(std::nullopt);
else else
return false; return false;

View File

@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
}]; }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
Type getKeyType() { return getType().cast<DictType>().getKeyType(); } Type getKeyType() { return cast<DictType>(getType()).getKeyType(); }
Type getValueType() { return getType().cast<DictType>().getValueType(); } Type getValueType() { return cast<DictType>(getType()).getValueType(); }
}]; }];
} }
@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>, DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.vtensor", TypesMatchWith<"operand is corresponding !torch.vtensor",
"result", "operand", "result", "operand",
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">, "cast<NonValueTensorType>($_self).getWithValueSemantics()">,
]> { ]> {
let summary = "Create a !torch.tensor with the same contents as the operand"; let summary = "Create a !torch.tensor with the same contents as the operand";
let description = [{ let description = [{
@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>, DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.tensor", TypesMatchWith<"operand is corresponding !torch.tensor",
"result", "operand", "result", "operand",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">, "cast<ValueTensorType>($_self).getWithoutValueSemantics()">,
]> { ]> {
let summary = "Create a !torch.vtensor with the same contents as the operand"; let summary = "Create a !torch.vtensor with the same contents as the operand";
let description = [{ let description = [{
@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type", TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
"value", "overwritten", "value", "overwritten",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()"> "cast<ValueTensorType>($_self).getWithoutValueSemantics()">
]> { ]> {
let summary = "Ovewrite the contents of tensor with values from another."; let summary = "Ovewrite the contents of tensor with values from another.";
let description = [{ let description = [{

View File

@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
} }
def AnyTorchTensorType : Type< def AnyTorchTensorType : Type<
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">, CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">,
"Any Torch tensor type" "Any Torch tensor type"
>; >;
@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType:
def AnyTorchOptionalGeneratorType: def AnyTorchOptionalGeneratorType:
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">; OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">; def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">;
class ListOf<list<Type> allowedTypes, string descr> : class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>, ContainerType<AnyTypeOf<allowedTypes>,
IsListTypePred, IsListTypePred,
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()", "cast<::mlir::torch::Torch::ListType>($_self).getContainedType()",
descr, "::mlir::torch::Torch::ListType">; descr, "::mlir::torch::Torch::ListType">;
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;

View File

@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNnModule(MlirType t) { bool torchMlirTypeIsATorchNnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>(); return isa<Torch::NnModuleType>(unwrap(t));
} }
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchOptional(MlirType t) { bool torchMlirTypeIsATorchOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>(); return isa<Torch::OptionalType>(unwrap(t));
} }
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchTuple(MlirType t) { bool torchMlirTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>(); return isa<Torch::TupleType>(unwrap(t));
} }
MlirType torchMlirTorchTupleTypeGet(MlirContext context, MlirType torchMlirTorchTupleTypeGet(MlirContext context,
@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchUnion(MlirType t) { bool torchMlirTypeIsATorchUnion(MlirType t) {
return unwrap(t).isa<Torch::UnionType>(); return isa<Torch::UnionType>(unwrap(t));
} }
MlirType torchMlirTorchUnionTypeGet(MlirContext context, MlirType torchMlirTorchUnionTypeGet(MlirContext context,
@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchList(MlirType t) { bool torchMlirTypeIsATorchList(MlirType t) {
return unwrap(t).isa<Torch::ListType>(); return isa<Torch::ListType>(unwrap(t));
} }
MlirType torchMlirTorchListTypeGet(MlirType containedType) { MlirType torchMlirTorchListTypeGet(MlirType containedType) {
@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchDevice(MlirType t) { bool torchMlirTypeIsATorchDevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>(); return isa<Torch::DeviceType>(unwrap(t));
} }
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) { MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchGenerator(MlirType t) { bool torchMlirTypeIsATorchGenerator(MlirType t) {
return unwrap(t).isa<Torch::GeneratorType>(); return isa<Torch::GeneratorType>(unwrap(t));
} }
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) { MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchBool(MlirType t) { bool torchMlirTypeIsATorchBool(MlirType t) {
return unwrap(t).isa<Torch::BoolType>(); return isa<Torch::BoolType>(unwrap(t));
} }
MlirType torchMlirTorchBoolTypeGet(MlirContext context) { MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchInt(MlirType t) { bool torchMlirTypeIsATorchInt(MlirType t) {
return unwrap(t).isa<Torch::IntType>(); return isa<Torch::IntType>(unwrap(t));
} }
MlirType torchMlirTorchIntTypeGet(MlirContext context) { MlirType torchMlirTorchIntTypeGet(MlirContext context) {
@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchFloat(MlirType t) { bool torchMlirTypeIsATorchFloat(MlirType t) {
return unwrap(t).isa<Torch::FloatType>(); return isa<Torch::FloatType>(unwrap(t));
} }
MlirType torchMlirTorchFloatTypeGet(MlirContext context) { MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchLinearParams(MlirType t) { bool torchMlirTypeIsATorchLinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>(); return isa<Torch::LinearParamsType>(unwrap(t));
} }
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) { MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchQInt8(MlirType t) { bool torchMlirTypeIsATorchQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>(); return isa<Torch::QInt8Type>(unwrap(t));
} }
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) { MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchQUInt8(MlirType t) { bool torchMlirTypeIsATorchQUInt8(MlirType t) {
return unwrap(t).isa<Torch::QUInt8Type>(); return isa<Torch::QUInt8Type>(unwrap(t));
} }
MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) { MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
@ -274,7 +274,7 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) { bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>(); return isa<Torch::NonValueTensorType>(unwrap(t));
} }
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
@ -341,7 +341,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchValueTensor(MlirType t) { bool torchMlirTypeIsATorchValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>(); return isa<Torch::ValueTensorType>(unwrap(t));
} }
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
@ -408,7 +408,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNone(MlirType t) { bool torchMlirTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>(); return isa<Torch::NoneType>(unwrap(t));
} }
MlirType torchMlirTorchNoneTypeGet(MlirContext context) { MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
@ -424,7 +424,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchString(MlirType t) { bool torchMlirTypeIsATorchString(MlirType t) {
return unwrap(t).isa<Torch::StringType>(); return isa<Torch::StringType>(unwrap(t));
} }
MlirType torchMlirTorchStringTypeGet(MlirContext context) { MlirType torchMlirTorchStringTypeGet(MlirContext context) {
@ -440,7 +440,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchAny(MlirType t) { bool torchMlirTypeIsATorchAny(MlirType t) {
return unwrap(t).isa<Torch::AnyType>(); return isa<Torch::AnyType>(unwrap(t));
} }
MlirType torchMlirTorchAnyTypeGet(MlirContext context) { MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
@ -456,7 +456,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchNumber(MlirType t) { bool torchMlirTypeIsATorchNumber(MlirType t) {
return unwrap(t).isa<Torch::NumberType>(); return isa<Torch::NumberType>(unwrap(t));
} }
MlirType torchMlirTorchNumberTypeGet(MlirContext context) { MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
@ -472,7 +472,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchDict(MlirType t) { bool torchMlirTypeIsATorchDict(MlirType t) {
return unwrap(t).isa<Torch::DictType>(); return isa<Torch::DictType>(unwrap(t));
} }
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) { MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {

View File

@ -546,12 +546,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value shuffledPaddingList = Value shuffledPaddingList =
createConstantIntList(binder, rewriter, padding); createConstantIntList(binder, rewriter, padding);
Value zero; Value zero;
if (resultTypeOut.getDtype().isa<FloatType>()) { if (isa<FloatType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantFloatOp>( zero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr( rewriter.getF64FloatAttr(
std::numeric_limits<double>::lowest())); std::numeric_limits<double>::lowest()));
} else if (resultTypeOut.getDtype().isa<IntegerType>()) { } else if (isa<IntegerType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantIntOp>( zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr( binder.getLoc(), rewriter.getI64IntegerAttr(
std::numeric_limits<int64_t>::lowest())); std::numeric_limits<int64_t>::lowest()));
@ -1295,7 +1295,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>(); auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
if (!inputTensorType || !inputTensorType.hasSizes()) { if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes"); binder.op, "Expected input type having sizes");
@ -1509,10 +1509,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (!constantValue) { if (!constantValue) {
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType()); auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
if (dataTensorType.getDtype().isa<IntegerType>()) if (isa<IntegerType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>( constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0)); loc, rewriter.getI64IntegerAttr(0));
if (dataTensorType.getDtype().isa<FloatType>()) if (isa<FloatType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantFloatOp>( constantValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0f)); loc, rewriter.getF64FloatAttr(0.0f));

View File

@ -1023,9 +1023,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc()); Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value constFalse = Value constFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false); rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto size = data.getType() auto size =
.dyn_cast<Torch::ValueTensorType>() dyn_cast<Torch::ValueTensorType>(data.getType()).getOptionalSizes();
.getOptionalSizes();
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>( auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
size, rewriter.getF64Type()); size, rewriter.getF64Type());
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>( Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
@ -2906,8 +2905,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
scalesValueList = noneVal; scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand); sizesValueList = getValueList(sizeOperand);
} }
if (scalesValueList.getType().isa<Torch::NoneType>() && if (isa<Torch::NoneType>(scalesValueList.getType()) &&
sizesValueList.getType().isa<Torch::NoneType>()) { isa<Torch::NoneType>(sizesValueList.getType())) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
} }
rewriter rewriter

View File

@ -1868,9 +1868,8 @@ public:
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.getSelf(); auto input = adaptor.getSelf();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
SmallVector<Value> resultShape; SmallVector<Value> resultShape;
SmallVector<Value> offsets; SmallVector<Value> offsets;
@ -2107,9 +2106,8 @@ public:
auto input = adaptor.getSelf(); auto input = adaptor.getSelf();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
SmallVector<Value> resultShape; SmallVector<Value> resultShape;
SmallVector<Value> offsets; SmallVector<Value> offsets;
@ -2343,9 +2341,8 @@ public:
op, "diagonal dimensions cannot be identical"); op, "diagonal dimensions cannot be identical");
Type elementType = inputType.getElementType(); Type elementType = inputType.getElementType();
RankedTensorType outputType = getTypeConverter() RankedTensorType outputType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Location loc = op.getLoc(); Location loc = op.getLoc();
Value dim1Size, dim2Size; Value dim1Size, dim2Size;
@ -2581,9 +2578,8 @@ public:
}) })
.getResult(0); .getResult(0);
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
return success(); return success();
@ -2608,9 +2604,8 @@ public:
return failure(); return failure();
// Conversion is completed specified by information in the sparse tensor // Conversion is completed specified by information in the sparse tensor
// type. Thus, we can rewrite all legalizedNames to the same construct. // type. Thus, we can rewrite all legalizedNames to the same construct.
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>( rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
op, resultType, adaptor.getOperands()[0]); op, resultType, adaptor.getOperands()[0]);
return success(); return success();

View File

@ -845,7 +845,7 @@ public:
outputSizeIntValues = getTypeConvertedValues( outputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputSizeTorchInt); rewriter, loc, getTypeConverter(), outputSizeTorchInt);
if (!op.getScalesH().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
// Convert float values to int values. // Convert float values to int values.
// int_value = (int64_t)ceil(float_value) // int_value = (int64_t)ceil(float_value)
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH()); Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH());
@ -858,7 +858,7 @@ public:
scaleFactorsInt.push_back(scaleFactorVal); scaleFactorsInt.push_back(scaleFactorVal);
} }
if (!op.getScalesW().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
// Convert float values to int values. // Convert float values to int values.
// int_value = (int64_t)ceil(float_value) // int_value = (int64_t)ceil(float_value)
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW()); Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW());
@ -1006,7 +1006,7 @@ public:
unsigned hDimOffset = 2; unsigned hDimOffset = 2;
SmallVector<Value, 2> scaleFactorsFloatValues; SmallVector<Value, 2> scaleFactorsFloatValues;
if (!op.getScalesH().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
scaleFactorsFloatValues.push_back(adaptor.getScalesH()); scaleFactorsFloatValues.push_back(adaptor.getScalesH());
} else { } else {
auto scaleFactorVal = rewriter.create<arith::DivFOp>( auto scaleFactorVal = rewriter.create<arith::DivFOp>(
@ -1019,7 +1019,7 @@ public:
scaleFactorsFloatValues.push_back(scaleFactorVal); scaleFactorsFloatValues.push_back(scaleFactorVal);
} }
if (!op.getScalesW().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
scaleFactorsFloatValues.push_back(adaptor.getScalesW()); scaleFactorsFloatValues.push_back(adaptor.getScalesW());
} else { } else {
auto scaleFactorVal = rewriter.create<arith::DivFOp>( auto scaleFactorVal = rewriter.create<arith::DivFOp>(

View File

@ -41,7 +41,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
return; return;
int64_t minSI = -(1 << (numBits - 1)); int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>( Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth()); loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue); zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits); minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric( arg = torch_to_linalg::createElementwiseLinalgGeneric(
@ -1057,10 +1057,10 @@ public:
loc, getAsOpFoldResult(outDims), accumulatorDType); loc, getAsOpFoldResult(outDims), accumulatorDType);
Value outputTensor; Value outputTensor;
if (accumulatorDType != resultDTy && !bias.getType().isa<Torch::NoneType>()) if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias, bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
accumulatorDType); accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(bias.getType())) {
Value c0; Value c0;
if (isa<mlir::FloatType>(accumulatorDType)) { if (isa<mlir::FloatType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>( c0 = rewriter.create<arith::ConstantOp>(

View File

@ -409,10 +409,8 @@ public:
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
RankedTensorType selfType = cast<RankedTensorType>(self.getType()); RankedTensorType selfType = cast<RankedTensorType>(self.getType());
Type elementType = selfType.getElementType(); Type elementType = selfType.getElementType();
RankedTensorType indicesRankedTensorType = RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
getTypeConverter() getTypeConverter()->convertType(op->getResult(1).getType()));
->convertType(op->getResult(1).getType())
.cast<RankedTensorType>();
// TODO: Add support for 3D inputs. // TODO: Add support for 3D inputs.
if (selfType.getRank() == 3) if (selfType.getRank() == 3)
@ -717,10 +715,10 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
outputType = typeConverter->convertType(op.getResult0().getType()) outputType = cast<RankedTensorType>(
.template cast<RankedTensorType>(); typeConverter->convertType(op.getResult0().getType()));
auxTensorType = typeConverter->convertType(op.getResult1().getType()) auxTensorType = cast<RankedTensorType>(
.template cast<RankedTensorType>(); typeConverter->convertType(op.getResult1().getType()));
Type auxTensorElementType = auxTensorType.getElementType(); Type auxTensorElementType = auxTensorType.getElementType();
auto smallestFPValueAttr = rewriter.getFloatAttr( auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
@ -799,8 +797,8 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
outputType = typeConverter->convertType(op.getResult().getType()) outputType = cast<RankedTensorType>(
.template cast<RankedTensorType>(); typeConverter->convertType(op.getResult().getType()));
buffVal = rewriter.create<arith::ConstantOp>( buffVal = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 0)); loc, elementType, rewriter.getFloatAttr(elementType, 0));
auxTensor = rewriter.create<tensor::EmptyOp>( auxTensor = rewriter.create<tensor::EmptyOp>(

View File

@ -42,9 +42,8 @@ public:
if (train) if (train)
return failure(); return failure();
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getInput()); adaptor.getInput());
return success(); return success();
@ -60,8 +59,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc,
Value result = Value result =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type())); b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type()));
for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) { for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) {
assert(index.getType().isa<mlir::IntegerType>() && assert(isa<mlir::IntegerType>(index.getType()) &&
stride.getType().isa<mlir::IntegerType>() && isa<mlir::IntegerType>(stride.getType()) &&
"Input arrays to `toLinearIndex` must only contain values of type " "Input arrays to `toLinearIndex` must only contain values of type "
"`mlir::IntegerType`"); "`mlir::IntegerType`");
Value mul = b.create<arith::MulIOp>(loc, result, stride); Value mul = b.create<arith::MulIOp>(loc, result, stride);
@ -129,7 +128,7 @@ public:
if (!isa<mlir::FloatType>(elemTy)) if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type"); return rewriter.notifyMatchFailure(op, "This op only support float type");
if (!generator.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
@ -180,7 +179,7 @@ public:
b.create<arith::MulFOp>(loc, updateFloat, scale); b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min); Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
Value truncRes = res; Value truncRes = res;
if (elemTy.isa<Float16Type, Float32Type>()) if (isa<Float16Type, Float32Type>(elemTy))
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res); truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
b.create<linalg::YieldOp>(loc, truncRes); b.create<linalg::YieldOp>(loc, truncRes);
}) })

View File

@ -86,11 +86,8 @@ public:
bool isUnsigned = false; bool isUnsigned = false;
if (!isa<mlir::FloatType>(inElementType)) { if (!isa<mlir::FloatType>(inElementType)) {
if (isa<mlir::IntegerType>(inElementType)) { if (isa<mlir::IntegerType>(inElementType)) {
auto integerTy = op.getSelf() auto integerTy = dyn_cast<mlir::IntegerType>(
.getType() cast<BaseTensorType>(op.getSelf().getType()).getDtype());
.template cast<BaseTensorType>()
.getDtype()
.template dyn_cast<mlir::IntegerType>();
isUnsigned = integerTy.isUnsigned(); isUnsigned = integerTy.isUnsigned();
} else { } else {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -280,7 +277,7 @@ public:
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem, static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
Type resultElementType) { Type resultElementType) {
if (elem.getType().isa<mlir::ComplexType>()) { if (isa<mlir::ComplexType>(elem.getType())) {
return b.create<complex::AbsOp>(loc, elem); return b.create<complex::AbsOp>(loc, elem);
} }
@ -376,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (isa<mlir::FloatType>(resultElementType)) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MaximumFOp>(loc, self, result); return b.create<arith::MaximumFOp>(loc, self, result);
else if (isa<mlir::IntegerType>(resultElementType)) { else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = max.getSelf() IntegerType intType = dyn_cast<mlir::IntegerType>(
.getType() cast<BaseTensorType>(max.getSelf().getType()).getDtype());
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
if (intType.isUnsigned()) if (intType.isUnsigned())
return b.create<arith::MaxUIOp>(loc, self, result); return b.create<arith::MaxUIOp>(loc, self, result);
if (intType.isSigned()) if (intType.isSigned())
@ -393,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (isa<mlir::FloatType>(resultElementType)) if (isa<mlir::FloatType>(resultElementType))
return b.create<arith::MinimumFOp>(loc, self, result); return b.create<arith::MinimumFOp>(loc, self, result);
else if (isa<mlir::IntegerType>(resultElementType)) { else if (isa<mlir::IntegerType>(resultElementType)) {
IntegerType intType = min.getSelf() IntegerType intType = dyn_cast<mlir::IntegerType>(
.getType() cast<BaseTensorType>(min.getSelf().getType()).getDtype());
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
if (intType.isUnsigned()) if (intType.isUnsigned())
return b.create<arith::MinUIOp>(loc, self, result); return b.create<arith::MinUIOp>(loc, self, result);
if (intType.isSigned()) if (intType.isSigned())
@ -657,9 +648,8 @@ public:
return opInfo; return opInfo;
Location loc = op->getLoc(); Location loc = op->getLoc();
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type elemType = resultType.getElementType(); Type elemType = resultType.getElementType();
LogicalResult elemTypeCheck = LogicalResult elemTypeCheck =
validateReductionElementType(op, elemType, rewriter); validateReductionElementType(op, elemType, rewriter);

View File

@ -179,15 +179,13 @@ public:
for (auto i : {TOP, VCENTER, BOTTOM}) { for (auto i : {TOP, VCENTER, BOTTOM}) {
for (auto j : {LEFT, HCENTER, RIGHT}) { for (auto j : {LEFT, HCENTER, RIGHT}) {
auto constVtile{ auto constVtile{dyn_cast_or_null<mlir::IntegerAttr>(
mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp()) mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp())
.getValue() .getValue())};
.dyn_cast_or_null<mlir::IntegerAttr>()};
auto constHtile{ auto constHtile{dyn_cast_or_null<mlir::IntegerAttr>(
mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp()) mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp())
.getValue() .getValue())};
.dyn_cast_or_null<mlir::IntegerAttr>()};
auto vSize = constVtile.getInt(); auto vSize = constVtile.getInt();
auto hSize = constHtile.getInt(); auto hSize = constHtile.getInt();
@ -369,8 +367,8 @@ public:
for (auto size : resultSize) for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = typeConverter->convertType(op.getType()) auto resultType =
.template cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType; Type resultElementType;
if (isa<Torch::NoneType>(op.getDtype().getType())) { if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = resultType.getElementType(); resultElementType = resultType.getElementType();
@ -426,7 +424,7 @@ public:
op, "unimplemented: pin_memory must be either None or false"); op, "unimplemented: pin_memory must be either None or false");
// Only `none`, `contiguous` and `preserve` memory_format is supported. // Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat; int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat))) m_TorchConstantInt(&memoryFormat)))
@ -441,7 +439,7 @@ public:
} }
// TODO: Add support for device arg other than cpu. // TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device; std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -453,7 +451,7 @@ public:
// TODO: Add support for non-strided layout. // TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0. // torch.layout is by default strided i.e. 0.
if (!op.getLayout().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout; int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -478,7 +476,7 @@ public:
auto resultType = auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType())); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
Type resultElementType; Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getDtype().getType())) {
resultElementType = getDefaultDtypeForTorchScalar( resultElementType = getDefaultDtypeForTorchScalar(
Torch::FloatType::get(op->getContext())); Torch::FloatType::get(op->getContext()));
} else { } else {
@ -527,7 +525,7 @@ public:
// The pin_memory should be either `False` or `none`. // The pin_memory should be either `False` or `none`.
bool pinMemory; bool pinMemory;
if (!op.getPinMemory().getType().isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) { pinMemory)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -536,9 +534,8 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
const TypeConverter *typeConverter = this->getTypeConverter(); const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type dtype = resultType.getElementType(); Type dtype = resultType.getElementType();
Value start = Value start =
convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);

View File

@ -138,17 +138,16 @@ public:
requires_grad = tensorFloatOp.getRequiresGrad(); requires_grad = tensorFloatOp.getRequiresGrad();
} }
// TODO: Dtype conversion. // TODO: Dtype conversion.
if (!dtype.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype"); return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
// TODO: Device information. // TODO: Device information.
if (!device.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(device.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented non-None device information"); op, "Unimplemented non-None device information");
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type outElementType = resultType.getElementType(); Type outElementType = resultType.getElementType();
Value elemValProm = Value elemValProm =
convertScalarToDtype(rewriter, loc, elemVal, outElementType); convertScalarToDtype(rewriter, loc, elemVal, outElementType);
@ -171,9 +170,8 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
Location loc = op.getLoc(); Location loc = op.getLoc();
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type outElementType = resultType.getElementType(); Type outElementType = resultType.getElementType();
Value elemVal = adaptor.getA(); Value elemVal = adaptor.getA();
Value elemValProm = Value elemValProm =

View File

@ -422,7 +422,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto clone = dyn_cast<AtenCloneOp>(op)) { if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat; int64_t memoryFormat;
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(clone.getMemoryFormat().getType()) &&
(!matchPattern(clone.getMemoryFormat(), (!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) || m_TorchConstantInt(&memoryFormat)) ||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous && (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
@ -434,24 +434,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return payloadArgs[0]; return payloadArgs[0];
} }
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) { if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType() if (isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(bitwiseAndTensor.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseAndTensor.emitError( bitwiseAndTensor.emitError(
"Bitwise_And does not support floating point dtype"); "Bitwise_And does not support floating point dtype");
return nullptr; return nullptr;
} }
Type dtype = converter->convertType(bitwiseAndTensor.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseAndTensor.getType()))
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs); return b.create<arith::AndIOp>(loc, lhs, rhs);
} }
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) { if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
Type dtype = converter->convertType(bitwiseAndScalar.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseAndScalar.getType()))
.getElementType(); .getElementType();
if (!isa<mlir::IntegerType>(dtype)) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseAndScalar.emitError( bitwiseAndScalar.emitError(
@ -469,32 +467,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::AndIOp>(loc, self, other); return b.create<arith::AndIOp>(loc, self, other);
} }
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) { if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (bitwiseOrTensor.getType() if (isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(bitwiseOrTensor.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseOrTensor.emitError( bitwiseOrTensor.emitError(
"Bitwise_Or does not support floating point dtype"); "Bitwise_Or does not support floating point dtype");
return nullptr; return nullptr;
} }
Type dtype = converter->convertType(bitwiseOrTensor.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseOrTensor.getType()))
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs); return b.create<arith::OrIOp>(loc, lhs, rhs);
} }
if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) { if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) {
if (bitwiseXorTensor.getType() if (isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(bitwiseXorTensor.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseXorTensor.emitError( bitwiseXorTensor.emitError(
"Bitwise_Xor does not support floating point dtype"); "Bitwise_Xor does not support floating point dtype");
return nullptr; return nullptr;
} }
Type dtype = converter->convertType(bitwiseXorTensor.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseXorTensor.getType()))
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
@ -502,8 +496,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto bitwiseRightShiftTensor = if (auto bitwiseRightShiftTensor =
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) { dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseRightShiftTensor.getType()))
.getElementType(); .getElementType();
if (!isa<mlir::IntegerType>(dtype)) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseRightShiftTensor.emitError( bitwiseRightShiftTensor.emitError(
@ -516,8 +510,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto bitwiseLeftShiftTensor = if (auto bitwiseLeftShiftTensor =
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) { dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) Type dtype = cast<RankedTensorType>(
.cast<RankedTensorType>() converter->convertType(bitwiseLeftShiftTensor.getType()))
.getElementType(); .getElementType();
if (!isa<mlir::IntegerType>(dtype)) { if (!isa<mlir::IntegerType>(dtype)) {
bitwiseLeftShiftTensor.emitError( bitwiseLeftShiftTensor.emitError(
@ -557,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createEqual(b, loc, floatDtype, self, zero); return createEqual(b, loc, floatDtype, self, zero);
} }
if (isa<AtenAbsOp>(op)) { if (isa<AtenAbsOp>(op)) {
if (payloadArgs[0].getType().isa<IntegerType>()) if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]); return b.create<math::AbsIOp>(loc, payloadArgs[0]);
return b.create<math::AbsFOp>(loc, payloadArgs[0]); return b.create<math::AbsFOp>(loc, payloadArgs[0]);
} }
@ -653,20 +647,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint); return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
} }
if (auto round = dyn_cast<AtenRoundOp>(op)) { if (auto round = dyn_cast<AtenRoundOp>(op)) {
if (!round.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(round.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
round.emitError("unimplemented: non-floating point dtype"); round.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]); return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
} }
if (auto prelu = dyn_cast<AtenPreluOp>(op)) { if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!prelu.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(prelu.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
prelu.emitError("unimplemented: non-floating point dtype"); prelu.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -685,10 +675,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart); return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
} }
if (auto gelu = dyn_cast<AtenGeluOp>(op)) { if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
if (!gelu.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(gelu.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
gelu.emitError("unimplemented: non-floating point dtype"); gelu.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -732,10 +720,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr; return nullptr;
} }
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) { if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
if (!geluBackward.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(geluBackward.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
geluBackward.emitError("unimplemented: non-floating point dtype"); geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -770,10 +756,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) { if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
AtenHardtanhBackwardOp::Adaptor adaptor(operands); AtenHardtanhBackwardOp::Adaptor adaptor(operands);
if (!hardtanhBackward.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(hardtanhBackward.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
hardtanhBackward.emitError("unimplemented: non-floating point dtype"); hardtanhBackward.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -967,10 +951,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) { if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
if (!pow.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(pow.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
pow.emitError("unimplemented: non-floating point dtype"); pow.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -1047,10 +1029,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) { if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
if (!lerp.getType() if (!isa<mlir::FloatType>(
.cast<ValueTensorType>() cast<ValueTensorType>(lerp.getType()).getDtype())) {
.getDtype()
.isa<mlir::FloatType>()) {
lerp.emitError("unimplemented: non-floating point dtype"); lerp.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
@ -1064,9 +1044,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) { if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype(); Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
Type elemTy = converter->convertType(minimum.getType()) Type elemTy =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(minimum.getType()))
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
Value pred = createLessThan(b, loc, dtype, lhs, rhs); Value pred = createLessThan(b, loc, dtype, lhs, rhs);
@ -1074,9 +1054,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) { if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype(); Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
Type elemTy = converter->convertType(maximum.getType()) Type elemTy =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(maximum.getType()))
.getElementType(); .getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs);
@ -1086,8 +1066,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
AtenClampOp::Adaptor adaptor(operands); AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.getMin(); auto min = adaptor.getMin();
auto max = adaptor.getMax(); auto max = adaptor.getMax();
if (min.getType().isa<Torch::OptionalType>() || if (isa<Torch::OptionalType>(min.getType()) ||
max.getType().isa<Torch::OptionalType>()) { isa<Torch::OptionalType>(max.getType())) {
clamp.emitError("unimplemented: runtime optional type"); clamp.emitError("unimplemented: runtime optional type");
return nullptr; return nullptr;
} }
@ -1125,9 +1105,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}; };
auto result = payloadArgs[0]; auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(min.getType()))
result = cmpSelect(result, min, /*getMax=*/false); result = cmpSelect(result, min, /*getMax=*/false);
if (!max.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(max.getType()))
result = cmpSelect(result, max, /*getMax=*/true); result = cmpSelect(result, max, /*getMax=*/true);
return result; return result;
} }
@ -1135,8 +1115,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
AtenClampTensorOp::Adaptor adaptor(operands); AtenClampTensorOp::Adaptor adaptor(operands);
auto min = adaptor.getMin(); auto min = adaptor.getMin();
auto max = adaptor.getMax(); auto max = adaptor.getMax();
if (min.getType().isa<Torch::OptionalType>() || if (isa<Torch::OptionalType>(min.getType()) ||
max.getType().isa<Torch::OptionalType>()) { isa<Torch::OptionalType>(max.getType())) {
clampTensor.emitError("unimplemented: runtime optional type"); clampTensor.emitError("unimplemented: runtime optional type");
return nullptr; return nullptr;
} }
@ -1145,7 +1125,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType(); .getElementType();
bool isMinNone = true; bool isMinNone = true;
auto result = payloadArgs[0]; auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(min.getType())) {
isMinNone = false; isMinNone = false;
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value pred; Value pred;
@ -1163,7 +1143,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result); result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
} }
if (!max.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(max.getType())) {
max = isMinNone ? payloadArgs[1] : payloadArgs[2]; max = isMinNone ? payloadArgs[1] : payloadArgs[2];
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
Value pred; Value pred;
@ -1252,9 +1232,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::DivFOp>(loc, self, other); return b.create<arith::DivFOp>(loc, self, other);
} }
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) { if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
Type newResultType = converter->convertType(remScalar.getType()) Type newResultType =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(remScalar.getType()))
.getElementType(); .getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
@ -1272,9 +1252,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result; return result;
} }
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) { if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
Type newResultType = converter->convertType(remTensor.getType()) Type newResultType =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(remTensor.getType()))
.getElementType(); .getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
@ -1292,9 +1272,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result; return result;
} }
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) { if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
Type newResultType = converter->convertType(fmod.getType()) Type newResultType =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(fmod.getType()))
.getElementType(); .getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
@ -1420,9 +1400,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) { if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
Type elementType = converter->convertType(bitwiseNot.getType()) Type elementType =
.cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(bitwiseNot.getType()))
.getElementType(); .getElementType();
if (isa<mlir::FloatType>(elementType)) { if (isa<mlir::FloatType>(elementType)) {
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
return nullptr; return nullptr;
@ -1607,10 +1587,9 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range( auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
operands, [](Value v) { return v.getType().isa<RankedTensorType>(); })); operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
bool hadErrorCreatingPayload = false; bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric( Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(), rewriter, loc, tensorOperands, resultType.getElementType(),
@ -1657,7 +1636,7 @@ public:
return rewriter.notifyMatchFailure(op, "dim must be constant"); return rewriter.notifyMatchFailure(op, "dim must be constant");
// TODO: Incorporate the weight argument. // TODO: Incorporate the weight argument.
if (!weight.getType().isa<mlir::torch::Torch::NoneType>()) if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented, the weight operand is not incorporated."); op, "Unimplemented, the weight operand is not incorporated.");
@ -1672,9 +1651,8 @@ public:
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected input and target to be rank <= 2"); op, "expected input and target to be rank <= 2");
} }
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
Value zeroVal = rewriter.create<arith::ConstantOp>( Value zeroVal = rewriter.create<arith::ConstantOp>(
@ -1948,7 +1926,7 @@ public:
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
Value target = adaptor.getTarget(); Value target = adaptor.getTarget();
Value weight = adaptor.getWeight(); Value weight = adaptor.getWeight();
bool weightIsNone = op.getWeight().getType().isa<Torch::NoneType>(); bool weightIsNone = isa<Torch::NoneType>(op.getWeight().getType());
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
Value totalWeight = adaptor.getTotalWeight(); Value totalWeight = adaptor.getTotalWeight();
@ -2069,9 +2047,8 @@ public:
}) })
->getResult(0); ->getResult(0);
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
return success(); return success();
} }
@ -2214,9 +2191,8 @@ public:
LogicalResult LogicalResult
matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getOperand()); adaptor.getOperand());
return success(); return success();
@ -2243,7 +2219,7 @@ public:
if (succeeded(checkNotNone(rewriter, op, eps))) if (succeeded(checkNotNone(rewriter, op, eps)))
handleEps = true; handleEps = true;
if (handleEps && !eps.getType().isa<mlir::FloatType>()) { if (handleEps && !isa<mlir::FloatType>(eps.getType())) {
op.emitError("Logit does not support non-floating point type"); op.emitError("Logit does not support non-floating point type");
return failure(); return failure();
} }
@ -2317,9 +2293,8 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf()); adaptor.getSelf());
return success(); return success();
@ -2362,8 +2337,8 @@ public:
zeropoint = converter->materializeTargetConversion( zeropoint = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
auto resultType = converter->convertType(op->getResult(0).getType()) auto resultType = cast<RankedTensorType>(
.cast<RankedTensorType>(); converter->convertType(op->getResult(0).getType()));
llvm::SmallVector<Value> dynSizes; llvm::SmallVector<Value> dynSizes;
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
@ -2553,9 +2528,8 @@ public:
return res; return res;
}; };
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op.getResult().getType()) getTypeConverter()->convertType(op.getResult().getType()));
.cast<RankedTensorType>();
SmallVector<Value> resultSize{}; SmallVector<Value> resultSize{};
if (resultType.isDynamicDim(0)) if (resultType.isDynamicDim(0))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@ -2675,7 +2649,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> scaleValues, SmallVector<Value> scaleValues,
std::string coordStr) { std::string coordStr) {
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
SmallVector<Value> indices; SmallVector<Value> indices;
@ -2725,7 +2699,7 @@ static Value BilinearInterpolate(OpBuilder &b,
SmallVector<Value> scaleValues, SmallVector<Value> scaleValues,
std::string coordStr) { std::string coordStr) {
unsigned dimOffset = 2; unsigned dimOffset = 2;
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
Value cstOneEps = Value cstOneEps =
@ -2877,7 +2851,7 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
Value input = adaptor.getInput(); Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>(); auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank(); auto inputRank = inputType.getRank();
if (mode.substr(0, 8) == "bilinear" && inputRank != 4) if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -2893,7 +2867,7 @@ public:
loc, rewriter.getIntegerType(64), inputSize)); loc, rewriter.getIntegerType(64), inputSize));
} }
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
bool recompScale; bool recompScale;
if (!matchPattern(op.getRecomputeScaleFactor(), if (!matchPattern(op.getRecomputeScaleFactor(),
m_TorchConstantBool(&recompScale))) m_TorchConstantBool(&recompScale)))

View File

@ -52,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor(
Value torch_to_linalg::getZeroPaddedTensor( Value torch_to_linalg::getZeroPaddedTensor(
Operation *op, OpBuilder &b, Value &input, Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &paddingInts) { SmallVectorImpl<int64_t> &paddingInts) {
assert(input.getType().isa<RankedTensorType>() && assert(isa<RankedTensorType>(input.getType()) &&
"input must be RankedTensorType"); "input must be RankedTensorType");
Location loc = op->getLoc(); Location loc = op->getLoc();
Value c0 = b.create<arith::ConstantOp>( Value c0 = b.create<arith::ConstantOp>(
@ -67,7 +67,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
Value torch_to_linalg::getDynamicZeroPaddedTensor( Value torch_to_linalg::getDynamicZeroPaddedTensor(
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding, Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
int unpaddedDims, Value pad) { int unpaddedDims, Value pad) {
assert(input.getType().isa<RankedTensorType>() && assert(isa<RankedTensorType>(input.getType()) &&
"input must be RankedTensorType"); "input must be RankedTensorType");
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank(); unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
Location loc = op->getLoc(); Location loc = op->getLoc();

View File

@ -252,7 +252,7 @@ public:
// "block" arguments // "block" arguments
for (const auto &barg : enumerate(op.getRegion().front().getArguments())) { for (const auto &barg : enumerate(op.getRegion().front().getArguments())) {
Value to = block->getArgument(barg.index()); Value to = block->getArgument(barg.index());
if (to.getType().isa<mlir::IndexType>()) if (isa<mlir::IndexType>(to.getType()))
to = to =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to); rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
Type targetType = to.getType(); Type targetType = to.getType();

View File

@ -146,9 +146,9 @@ public:
if (!selfType) { if (!selfType) {
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
} }
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
self = hlo::promoteType(rewriter, op.getLoc(), self, outType); self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self); rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success(); return success();
@ -203,9 +203,9 @@ public:
auto selfTy = cast<TensorType>(self.getType()); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy) if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto resultTy = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
if (isa<mlir::FloatType>(resultTy.getElementType())) { if (isa<mlir::FloatType>(resultTy.getElementType())) {
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
@ -231,9 +231,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = dyn_cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template dyn_cast<TensorType>(); op.getType()));
if (!outType) if (!outType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
@ -321,9 +321,9 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError("only Tensor types supported"); return op.emitError("only Tensor types supported");
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto outTy = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
@ -354,9 +354,9 @@ public:
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) { if (!outElemTy.isIntOrFloat()) {
@ -607,9 +607,9 @@ public:
if (!lhsTy) if (!lhsTy)
return op.emitError("lhs must be a ranked tensor type"); return op.emitError("lhs must be a ranked tensor type");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
if (!rhsTy) { if (!rhsTy) {
@ -917,9 +917,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
.template cast<TensorType>(); ->convertType(op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) { if (!outElemTy.isIntOrFloat()) {
@ -1421,9 +1421,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec( SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType() numFeatureDimSize,
.cast<mlir::FloatType>() APFloat::getZero(
.getFloatSemantics())); cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics()));
SmallVector<APFloat> oneConstVec( SmallVector<APFloat> oneConstVec(
numFeatureDimSize, numFeatureDimSize,
APFloat( APFloat(
@ -1633,9 +1633,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
Location loc = op->getLoc(); Location loc = op->getLoc();
// Get element type of resultType as dtype // Get element type of resultType as dtype
auto outType = this->getTypeConverter() auto outType = cast<RankedTensorType>(
->convertType(op.getType()) this->getTypeConverter()->convertType(op.getType()));
.cast<RankedTensorType>();
auto dtype = outType.getElementType(); auto dtype = outType.getElementType();
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) { if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1678,7 +1677,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
AtenConstantPadNdOp op, OpAdaptor adaptor, AtenConstantPadNdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
auto selfElemTy = selfTy.getElementType(); auto selfElemTy = selfTy.getElementType();
int64_t rank = selfTy.getRank(); int64_t rank = selfTy.getRank();
@ -2029,7 +2028,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<RankedTensorType>(); auto selfTy = cast<RankedTensorType>(self.getType());
if (!selfTy.hasStaticShape()) { if (!selfTy.hasStaticShape()) {
return op->emitError("dynamic shaped input is not supported"); return op->emitError("dynamic shaped input is not supported");
} }
@ -2062,7 +2061,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
cmpTypeAttr); cmpTypeAttr);
auto resTy = auto resTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto bcastTy = resTy.clone(rewriter.getI1Type()); auto bcastTy = resTy.clone(rewriter.getI1Type());
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1}); auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
@ -2071,15 +2070,15 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
auto resElemTy = resTy.getElementType(); auto resElemTy = resTy.getElementType();
Value zeroTensor; Value zeroTensor;
if (resElemTy.isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(resElemTy)) {
auto constAttr = SplatElementsAttr::get( auto constAttr = SplatElementsAttr::get(
resTy, llvm::APFloat::getZero( resTy, llvm::APFloat::getZero(
resElemTy.cast<FloatType>().getFloatSemantics(), false)); cast<FloatType>(resElemTy).getFloatSemantics(), false));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr); zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else if (resElemTy.isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(resElemTy)) {
auto constAttr = SplatElementsAttr::get( auto constAttr = SplatElementsAttr::get(
resTy, resTy,
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth())); llvm::APInt::getZero(cast<mlir::IntegerType>(resElemTy).getWidth()));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr); zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else { } else {
return op.emitError("element type is not float or integer"); return op.emitError("element type is not float or integer");

View File

@ -157,8 +157,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value builtinTypeStart = adaptor.getStart(); Value builtinTypeStart = adaptor.getStart();
Value builtinTypeEnd = adaptor.getEnd(); Value builtinTypeEnd = adaptor.getEnd();
if (torchTypeStart.getType().isa<OptionalType>() || if (isa<OptionalType>(torchTypeStart.getType()) ||
torchTypeEnd.getType().isa<OptionalType>()) isa<OptionalType>(torchTypeEnd.getType()))
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
int64_t step; int64_t step;
@ -349,11 +349,11 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "offsets must be a vector with static shape equal to 1"); op, "offsets must be a vector with static shape equal to 1");
if (!op.getPaddingIdx().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getPaddingIdx().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented: padding_idx should be none"); op, "Unimplemented: padding_idx should be none");
if (!op.getPerSampleWeights().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getPerSampleWeights().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented: per_sample_weights should be none"); op, "Unimplemented: per_sample_weights should be none");
@ -453,25 +453,22 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
loc, getTypeConverter()->convertType(op.getType(0)), loc, getTypeConverter()->convertType(op.getType(0)),
stablehloReduceOp.getResult(0), outShapeTensor); stablehloReduceOp.getResult(0), outShapeTensor);
RankedTensorType resultType = getTypeConverter() RankedTensorType resultType = cast<RankedTensorType>(
->convertType(op->getResult(1).getType()) getTypeConverter()->convertType(op->getResult(1).getType()));
.cast<RankedTensorType>();
Value resultB = Value resultB =
createInitialValueForGatherScatterOp(op, resultType, rewriter); createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultB) if (!resultB)
return failure(); return failure();
resultType = getTypeConverter() resultType = cast<RankedTensorType>(
->convertType(op->getResult(2).getType()) getTypeConverter()->convertType(op->getResult(2).getType()));
.cast<RankedTensorType>();
Value resultC = Value resultC =
createInitialValueForGatherScatterOp(op, resultType, rewriter); createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultC) if (!resultC)
return failure(); return failure();
resultType = getTypeConverter() resultType = cast<RankedTensorType>(
->convertType(op->getResult(3).getType()) getTypeConverter()->convertType(op->getResult(3).getType()));
.cast<RankedTensorType>();
Value resultD = Value resultD =
createInitialValueForGatherScatterOp(op, resultType, rewriter); createInitialValueForGatherScatterOp(op, resultType, rewriter);
if (!resultD) if (!resultD)
@ -612,9 +609,8 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
auto input = adaptor.getSelf(); auto input = adaptor.getSelf();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
SmallVector<Value> resultShape; SmallVector<Value> resultShape;
SmallVector<Value> offsets; SmallVector<Value> offsets;

View File

@ -350,9 +350,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, op,
ConvertAtenOp<AtenOpT>::getTypeConverter() cast<RankedTensorType>(
->convertType(op.getType()) ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(
.template cast<RankedTensorType>(), op.getType())),
output); output);
return success(); return success();
@ -730,9 +730,8 @@ public:
// If transposed is set to true, // If transposed is set to true,
// the weight shape changes to [IC, (OC//G), KH, KW] // the weight shape changes to [IC, (OC//G), KH, KW]
auto weightTy = cast<RankedTensorType>(weight.getType()); auto weightTy = cast<RankedTensorType>(weight.getType());
auto outTy = getTypeConverter() auto outTy =
->convertType(op.getType()) cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
.template cast<RankedTensorType>();
if (!inputTy || !weightTy || !outTy) { if (!inputTy || !weightTy || !outTy) {
return op.emitError("input, weight and output must be ranked tensors"); return op.emitError("input, weight and output must be ranked tensors");
} }

View File

@ -216,10 +216,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr; stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT); rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED); rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
} }
@ -395,9 +395,8 @@ public:
RankedTensorType inputTy = cast<RankedTensorType>(input.getType()); RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
Type inputElemTy = inputTy.getElementType(); Type inputElemTy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank(); int64_t inputRank = inputTy.getRank();
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter() RankedTensorType outTy = cast<RankedTensorType>(
->convertType(op.getType()) ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
.template cast<RankedTensorType>();
auto outShape = outTy.getShape(); auto outShape = outTy.getShape();
if (inputRank <= Dim) { if (inputRank <= Dim) {

View File

@ -242,10 +242,10 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr; stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT); rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED); rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
} }
@ -535,12 +535,10 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
"AtenMaxDimOp to StableHLO"); "AtenMaxDimOp to StableHLO");
} }
RankedTensorType valResultType = getTypeConverter() RankedTensorType valResultType = cast<RankedTensorType>(
->convertType(op.getResult(0).getType()) getTypeConverter()->convertType(op.getResult(0).getType()));
.template cast<RankedTensorType>(); RankedTensorType idxResultType = cast<RankedTensorType>(
RankedTensorType idxResultType = getTypeConverter() getTypeConverter()->convertType(op.getResult(1).getType()));
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType(); Type idxElementType = idxResultType.getElementType();
if (!isa<mlir::IntegerType>(idxElementType)) { if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result"); return op.emitError("Aten.max.dim needs integer-like result");
@ -636,9 +634,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType()); auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter() auto outTy =
->convertType(op.getType()) dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
.template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO"); op, "only Tensor types supported in StableHLO");

View File

@ -271,7 +271,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto getOptionalVal = [&](Value val) -> std::optional<Value> { auto getOptionalVal = [&](Value val) -> std::optional<Value> {
if (val.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(val.getType())) {
return std::nullopt; return std::nullopt;
} else { } else {
return val; return val;
@ -451,7 +451,7 @@ template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite( LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor, PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>(); auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
if (!selfType) { if (!selfType) {
return op.emitError("only tensor types are currently supported"); return op.emitError("only tensor types are currently supported");
} }

View File

@ -292,7 +292,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
arith::CmpIPredicate predicate = isDescending ? ge : le; arith::CmpIPredicate predicate = isDescending ? ge : le;
compareOp = rewriter.create<arith::CmpIOp>( compareOp = rewriter.create<arith::CmpIOp>(
loc, predicate, block->getArgument(0), block->getArgument(1)); loc, predicate, block->getArgument(0), block->getArgument(1));
} else if (elementTypes[0].isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(elementTypes[0])) {
// Case for using arith::CmpFOp. // Case for using arith::CmpFOp.
arith::CmpFPredicate predicate = arith::CmpFPredicate predicate =
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE; isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
@ -349,8 +349,8 @@ public:
b.create<TMTensor::YieldOp>(loc, updatesElement); b.create<TMTensor::YieldOp>(loc, updatesElement);
}); });
auto resultType = typeConverter->convertType(op->getResult(0).getType()) auto resultType = cast<RankedTensorType>(
.cast<RankedTensorType>(); typeConverter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success(); return success();
} }
@ -381,7 +381,7 @@ public:
// Check whether the input is a 1-d tensor of integer type or not. // Check whether the input is a 1-d tensor of integer type or not.
RankedTensorType inputType = cast<RankedTensorType>(input.getType()); RankedTensorType inputType = cast<RankedTensorType>(input.getType());
if (inputType.getRank() != 1 || if (inputType.getRank() != 1 ||
!inputType.getElementType().isa<mlir::IntegerType>()) !isa<mlir::IntegerType>(inputType.getElementType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, op,
"Input tensor has to be a one-dimensional tensor of integer type."); "Input tensor has to be a one-dimensional tensor of integer type.");
@ -395,7 +395,7 @@ public:
"Unimplemented: Integer width not equal to 64 are not supported."); "Unimplemented: Integer width not equal to 64 are not supported.");
// TODO: Incorporate the weight argument. // TODO: Incorporate the weight argument.
if (!weights.getType().isa<mlir::torch::Torch::NoneType>()) if (!isa<mlir::torch::Torch::NoneType>(weights.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented: the weights operand is not incorporated."); op, "Unimplemented: the weights operand is not incorporated.");
@ -439,8 +439,8 @@ public:
indices = typeConverter->materializeTargetConversion( indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices); rewriter, loc, typeConverter->convertType(indices.getType()), indices);
auto resultType = typeConverter->convertType(op->getResult(0).getType()) auto resultType = cast<RankedTensorType>(
.cast<RankedTensorType>(); typeConverter->convertType(op->getResult(0).getType()));
Type resultElemType = resultType.getElementType(); Type resultElemType = resultType.getElementType();
SmallVector<Value, 1> inputSizeDynamic = SmallVector<Value, 1> inputSizeDynamic =
@ -686,8 +686,8 @@ public:
auto valuesType = cast<ValueTensorType>(values.getType()); auto valuesType = cast<ValueTensorType>(values.getType());
int64_t inputRank = inputType.getSizes().size(); int64_t inputRank = inputType.getSizes().size();
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType()); auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
auto resultType = typeConverter->convertType(op->getResult(0).getType()) auto resultType = cast<RankedTensorType>(
.cast<RankedTensorType>(); typeConverter->convertType(op->getResult(0).getType()));
if (!valuesTensorType.hasSizes()) if (!valuesTensorType.hasSizes())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -823,10 +823,10 @@ public:
Value inputElement) { Value inputElement) {
Value yieldValue = valuesElement; Value yieldValue = valuesElement;
if (accumulate) { if (accumulate) {
if (inputElement.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue = yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement); b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue = yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement); b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else { } else {
@ -1042,10 +1042,10 @@ public:
[&](OpBuilder &b, Location loc, Value valuesElement, [&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) { Value inputElement) {
Value yieldValue = valuesElement; Value yieldValue = valuesElement;
if (inputElement.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue = yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement); b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue = yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement); b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else { } else {
@ -1204,33 +1204,33 @@ public:
Value result; Value result;
if (reduceEnum == torch_upstream::ReductionType::SUM || if (reduceEnum == torch_upstream::ReductionType::SUM ||
reduceEnum == torch_upstream::ReductionType::MEAN) { reduceEnum == torch_upstream::ReductionType::MEAN) {
if (update.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::AddIOp>(loc, update, current); result = b.create<arith::AddIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::AddFOp>(loc, update, current); result = b.create<arith::AddFOp>(loc, update, current);
} else { } else {
llvm_unreachable("Only integer/float types supported!"); llvm_unreachable("Only integer/float types supported!");
} }
} else if (reduceEnum == torch_upstream::ReductionType::PROD) { } else if (reduceEnum == torch_upstream::ReductionType::PROD) {
if (update.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MulIOp>(loc, update, current); result = b.create<arith::MulIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MulFOp>(loc, update, current); result = b.create<arith::MulFOp>(loc, update, current);
} else { } else {
llvm_unreachable("Only integer/float types supported!"); llvm_unreachable("Only integer/float types supported!");
} }
} else if (reduceEnum == torch_upstream::ReductionType::MAX) { } else if (reduceEnum == torch_upstream::ReductionType::MAX) {
if (update.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MaxSIOp>(loc, update, current); result = b.create<arith::MaxSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MaximumFOp>(loc, update, current); result = b.create<arith::MaximumFOp>(loc, update, current);
} else { } else {
llvm_unreachable("Only integer/float types supported!"); llvm_unreachable("Only integer/float types supported!");
} }
} else if (reduceEnum == torch_upstream::ReductionType::MIN) { } else if (reduceEnum == torch_upstream::ReductionType::MIN) {
if (update.getType().isa<mlir::IntegerType>()) { if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MinSIOp>(loc, update, current); result = b.create<arith::MinSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) { } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MinimumFOp>(loc, update, current); result = b.create<arith::MinimumFOp>(loc, update, current);
} else { } else {
llvm_unreachable("Only integer/float types supported!"); llvm_unreachable("Only integer/float types supported!");
@ -1285,9 +1285,8 @@ public:
}) })
.getResult()[0]; .getResult()[0];
} }
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success(); return success();
@ -1392,9 +1391,8 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto resultType = getTypeConverter() auto resultType = cast<RankedTensorType>(
->convertType(op->getResult(0).getType()) getTypeConverter()->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
Type elementType = resultType.getElementType(); Type elementType = resultType.getElementType();
Type inputElementType = Type inputElementType =
cast<RankedTensorType>(input.getType()).getElementType(); cast<RankedTensorType>(input.getType()).getElementType();
@ -1414,7 +1412,7 @@ public:
int64_t inputRank = resultType.getRank(); int64_t inputRank = resultType.getRank();
Value dtype = op.getDtype(); Value dtype = op.getDtype();
if (!dtype.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unsupported: dtype argument not supported"); op, "unsupported: dtype argument not supported");
@ -1444,7 +1442,7 @@ public:
rewriter, loc, input, output, acc, dim, /*inclusive=*/true, rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
[](OpBuilder &b, Location loc, Value input, Value acc) { [](OpBuilder &b, Location loc, Value input, Value acc) {
Value sum = Value sum =
(input.getType().isa<mlir::FloatType>() (isa<mlir::FloatType>(input.getType())
? b.create<arith::AddFOp>(loc, input, acc)->getResult(0) ? b.create<arith::AddFOp>(loc, input, acc)->getResult(0)
: b.create<arith::AddIOp>(loc, input, acc)->getResult(0)); : b.create<arith::AddIOp>(loc, input, acc)->getResult(0));
b.create<TMTensor::YieldOp>(loc, sum); b.create<TMTensor::YieldOp>(loc, sum);
@ -1472,7 +1470,7 @@ public:
cast<ShapedType>(adaptor.getQuery().getType()).getElementType(); cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
// Verify inputs (only support defaults) // Verify inputs (only support defaults)
if (!mask.getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(), return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported"); "attention masking not supported");
double dropout; double dropout;
@ -1483,7 +1481,7 @@ public:
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported"); op.getLoc(), "causal attention masking not supported");
if (!scale.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat; double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) || if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0) scaleFloat != 1.0)

View File

@ -1,5 +1,5 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ////
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@ -47,7 +47,7 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
if (selfTy.getElementType().isa<mlir::FloatType>()) { if (isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<TosaOpT>( rewriter.replaceOpWithNewOp<TosaOpT>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@ -99,9 +99,9 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto outTy = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
auto binaryOp = auto binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs); tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
@ -248,9 +248,9 @@ public:
} }
// Get output type: tensor<i32/i64/f32> // Get output type: tensor<i32/i64/f32>
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) { if (!outElemTy.isIntOrFloat()) {
@ -373,9 +373,9 @@ public:
std::is_same<AtenOpT, AtenLtScalarOp>()); std::is_same<AtenOpT, AtenLtScalarOp>());
// Promote lhs and rhs dtypes for bitwise operators. // Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType resultTy = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
if (isBitwiseOp) { if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy); lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
@ -416,9 +416,9 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) if (!outElemTy.isIntOrFloat())
@ -444,9 +444,9 @@ public:
} }
if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) { if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs, auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsTensor, /*shift=*/0); rhsTensor, /*shift=*/0);
@ -492,9 +492,9 @@ public:
"conversion in TOSA operation"); "conversion in TOSA operation");
} }
auto rhsTensor = rhsTy ? rhs : rhsAsTensor; auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
// auto result; // auto result;
Value result; Value result;
@ -540,7 +540,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType()); auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::TanhOp>( rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
@ -557,7 +557,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType()); auto selfTy = cast<TensorType>(self.getType());
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>( rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
@ -584,7 +584,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
} }
// Rescale the clampIn for quantized types. TBD // Rescale the clampIn for quantized types. TBD
if (!selfTy.getElementType().isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported"); op, "Only floating-point datatype legalization currently supported");
} }
@ -604,7 +604,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType()); auto selfTy = cast<TensorType>(self.getType());
if (!selfTy.getElementType().isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported"); op, "Only floating-point datatype legalization currently supported");
} }
@ -667,9 +667,9 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto outputTy = cast<RankedTensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<RankedTensorType>(); op.getType()));
if (!outputTy) if (!outputTy)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted for reduce_mean"); op, "Only ranked tensor type outputs permitted for reduce_mean");
@ -828,9 +828,8 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported"); op, "non-const keepdim parameter unsupported");
auto resultTy = getTypeConverter() auto resultTy = cast<RankedTensorType>(
->convertType(op.getResult().getType()) getTypeConverter()->convertType(op.getResult().getType()));
.cast<RankedTensorType>();
auto outputETy = resultTy.getElementType(); auto outputETy = resultTy.getElementType();
// Create a single instance of tosa.argmax. // Create a single instance of tosa.argmax.
@ -927,9 +926,9 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Squeeze could not compute new shape"); "Squeeze could not compute new shape");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto resultTy = cast<RankedTensorType>(
->convertType(op.getResult().getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<RankedTensorType>(); op.getResult().getType()));
auto resultElemTy = resultTy.getElementType(); auto resultElemTy = resultTy.getElementType();
auto newOutputTy = RankedTensorType::get( auto newOutputTy = RankedTensorType::get(
@ -1017,7 +1016,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow"); op, "Only ranked tensor types supported in TOSA Pow");
if (!selfTy.getElementType().isa<mlir::FloatType>()) if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported"); op, "Only floating-point datatype legalization supported");
@ -1624,9 +1623,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter() cast<RankedTensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<RankedTensorType>(), op.getType())),
output); output);
return success(); return success();
@ -1800,9 +1799,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter() cast<RankedTensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<RankedTensorType>(), op.getType())),
matmulPlusBias); matmulPlusBias);
return success(); return success();
@ -1823,7 +1822,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub"); op, "Only ranked tensor types supported in TOSA Rsub");
if (!selfTy.getElementType().isa<mlir::FloatType>()) if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported"); op, "Only floating-point datatype legalization supported");
@ -1869,9 +1868,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto inputTy = cast<RankedTensorType>(input.getType()); auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = cast<RankedTensorType>(weight.getType()); auto weightTy = cast<RankedTensorType>(weight.getType());
auto outputTy = getTypeConverter() auto outputTy =
->convertType(op.getType()) cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
.template cast<RankedTensorType>();
if (!inputTy || !weightTy || !outputTy) if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -2208,7 +2206,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled. // Note: cudnn_enabled is not handled.
// FIXME: Handle training and momentum. // FIXME: Handle training and momentum.
if (op.getMomentum().getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(op.getMomentum().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum"); return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType()); auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
@ -2312,9 +2310,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled. // Note: cudnn_enabled is not handled.
// FIXME: Handle the None cases for the optional parameters. // FIXME: Handle the None cases for the optional parameters.
if (adaptor.getWeight().getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(adaptor.getWeight().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for weight"); return rewriter.notifyMatchFailure(op, "Unsupported None for weight");
if (adaptor.getBias().getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(adaptor.getBias().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType()); auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
@ -2453,9 +2451,8 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor, ValueTensorLiteralOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto outputTy = getTypeConverter() auto outputTy =
->convertType(op.getType()) cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
.template cast<RankedTensorType>();
// Tensors with integer types need to be converted to signless integer // Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse // element type. All tensors with element types other than integer can reuse
@ -3122,7 +3119,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
cast<RankedTensorType>(typeConverter->convertType(op.getType())); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto indicesType = dyn_cast<RankedTensorType>(indices.getType()); auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
if (!indicesType || !indicesType.getElementType().isa<IntegerType>()) if (!indicesType || !isa<IntegerType>(indicesType.getElementType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Indices must be of integer tensor type"); op, "Indices must be of integer tensor type");
@ -3632,11 +3629,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
auto indexTorch = tensorsTorchType[i]; auto indexTorch = tensorsTorchType[i];
// TODO add support for none index other than i==0, like (index0, None) // TODO add support for none index other than i==0, like (index0, None)
// (None, index1) // (None, index1)
if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) { if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
// convert None to [0,0,0] // convert None to [0,0,0]
auto indexNext = indexTensors[i + 1]; auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1]; auto indexNextTorch = tensorsTorchType[i + 1];
if (indexNextTorch.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(indexNextTorch.getType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now."); op, "Multiple None index is not support for now.");
} }
@ -3963,8 +3960,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
if (!selfType.hasStaticShape() || !otherType.hasStaticShape()) if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported"); op, "Only tensor types with static shape are supported");
if (!selfType.getElementType().isa<mlir::FloatType>() || if (!isa<mlir::FloatType>(selfType.getElementType()) ||
!otherType.getElementType().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(otherType.getElementType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: only FP element type is supported"); op, "unimplemented: only FP element type is supported");
} }
@ -4058,9 +4055,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
const TypeConverter *typeConverter = this->getTypeConverter(); const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
// At this point all tensors should have value semantics, and hence the // At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored. // `layout` check can be ignored.
@ -4068,7 +4064,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
// TODO: Add support for pin_memory features. // TODO: Add support for pin_memory features.
// The pin_memory should be either `False` or `none`. // The pin_memory should be either `False` or `none`.
bool pinMemory; bool pinMemory;
if (!op.getPinMemory().getType().isa<Torch::NoneType>() && if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) { pinMemory)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -4162,10 +4158,10 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
}; };
const auto isIntType = const auto isIntType =
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>(); dyn_cast_or_null<mlir::IntegerType>(resultType.getElementType());
const auto isDoubleType = const auto isDoubleType =
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>(); dyn_cast_or_null<mlir::FloatType>(resultType.getElementType());
auto maybeResult = [&]() -> std::optional<Value> { auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers. // Integer output type, and start / end / range are all integers.
@ -4218,9 +4214,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
const TypeConverter *typeConverter = this->getTypeConverter(); const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()) typeConverter->convertType(op->getResult(0).getType()));
.cast<RankedTensorType>();
// Only supports integer operand type, because for the floating point operand // Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the // type result tensor has to be of type `f64` which is not supported in the
@ -4323,7 +4318,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
} }
// Only `none`, `contiguous` and `preserve` memory_format is supported. // Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat; int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -4336,9 +4331,8 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
"memory_format is supported"); "memory_format is supported");
} }
auto resultTy = getTypeConverter() auto resultTy = cast<RankedTensorType>(
->convertType(op.getResult().getType()) getTypeConverter()->convertType(op.getResult().getType()));
.cast<RankedTensorType>();
Value result; Value result;
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(), if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),
@ -4779,9 +4773,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = dyn_cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template dyn_cast<TensorType>(); op.getType()));
if (!outType) if (!outType)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -4841,9 +4835,9 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = dyn_cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template dyn_cast<TensorType>(); op.getType()));
if (!outType || !outType.hasStaticShape()) if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -4875,9 +4869,9 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = dyn_cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template dyn_cast<TensorType>(); op.getType()));
if (!outType || !outType.hasStaticShape()) if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -4947,9 +4941,9 @@ public:
"unimplemented: only contiguous and channels last memory " "unimplemented: only contiguous and channels last memory "
"format is supported"); "format is supported");
} }
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = dyn_cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template dyn_cast<TensorType>(); op.getType()));
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.getSelf()); rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.getSelf());
return success(); return success();
@ -5077,8 +5071,8 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA"); "Only Tensor types supported in TOSA");
auto resultType = typeConverter->convertType(op.getType()) auto resultType =
.template cast<RankedTensorType>(); cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto elementType = resultType.getElementType(); auto elementType = resultType.getElementType();
if (isa<mlir::IntegerType>(selfTy.getElementType())) { if (isa<mlir::IntegerType>(selfTy.getElementType())) {

View File

@ -813,9 +813,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype = bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype || output_is_qtype) { if (input_is_qtype || output_is_qtype) {
op->emitOpError("ConvertReduceProdOp: input/output tensor should " op->emitOpError("ConvertReduceProdOp: input/output tensor should "
@ -839,9 +839,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype = bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype != output_is_qtype) { if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should " op->emitOpError("ConvertReduceSumOp: input/output tensor should "
@ -894,9 +894,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype = bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
if (input_is_qtype != output_is_qtype) { if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should " op->emitOpError("ConvertReduceSumOp: input/output tensor should "
@ -905,7 +905,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
} }
// Only supports float type mean() if it's non-quantized // Only supports float type mean() if it's non-quantized
if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) { if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
op->emitWarning( op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element " "Failed convertReduceMean: input unquantized type but output element "
"not FloatType!"); "not FloatType!");

View File

@ -31,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
return false; return false;
auto tensor = dyn_cast<ValueTensorType>(type); auto tensor = dyn_cast<ValueTensorType>(type);
return !tensor || return !tensor ||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>(); dyn_cast_or_null<RankedTensorType>(tensor.toBuiltinTensor());
}; };
bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) && bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
@ -66,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
// Generate IR: assert(dim >= 0 && dim < inputRank) // Generate IR: assert(dim >= 0 && dim < inputRank)
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) { void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
assert(dim.getType().isa<IntegerType>() && assert(isa<IntegerType>(dim.getType()) &&
"dim arg of assertIsValidDim must be integer type"); "dim arg of assertIsValidDim must be integer type");
Value cst0 = Value cst0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType())); b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
@ -139,12 +139,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
} }
Value castIntToIndex(OpBuilder &b, Location loc, Value v) { Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(v.getType().isa<IntegerType>() && "must be called with integer type"); assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v); return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
} }
Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) { Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
assert(idx.getType().isa<IndexType>() && "must be called with integer type"); assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx); return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
} }
@ -375,7 +375,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt, Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize) { Value defaultValue, Value dimSize) {
if (torchOptionalInt.getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(torchOptionalInt.getType()))
return defaultValue; return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim = Value positiveDim =

View File

@ -149,14 +149,12 @@ static Value getScalarIntValue(Value input, Location loc,
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) { if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
if (inputDtype.isInteger(64)) { if (inputDtype.isInteger(64)) {
auto val = valueTensorLiteralOp.getValue() auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>(); .getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>( return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val)); loc, rewriter.getI64IntegerAttr(val));
} else { } else {
auto val = valueTensorLiteralOp.getValue() auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.cast<DenseIntElementsAttr>()
.getSplatValue<bool>(); .getSplatValue<bool>();
return rewriter.create<Torch::ConstantIntOp>( return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val)); loc, rewriter.getI64IntegerAttr(val));
@ -191,8 +189,7 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr; return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) { if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue() auto val = cast<DenseFPElementsAttr>(valueTensorLiteralOp.getValue())
.cast<DenseFPElementsAttr>()
.getSplatValue<FloatAttr>() .getSplatValue<FloatAttr>()
.getValueAsDouble(); .getValueAsDouble();
return rewriter.create<Torch::ConstantFloatOp>( return rewriter.create<Torch::ConstantFloatOp>(
@ -1946,7 +1943,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType()); auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() && if (resultType && resultType.hasDtype() &&
resultType.getDtype().isa<mlir::IntegerType>()) { isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf(); return getSelf();
} }
return {}; return {};
@ -2136,7 +2133,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
// Limit the loop count to 6 to avoid indefinite compilation times from // Limit the loop count to 6 to avoid indefinite compilation times from
// unbounded IR traversals. // unbounded IR traversals.
for (auto idx = 0; idx < 6; ++idx) { for (auto idx = 0; idx < 6; ++idx) {
if (!value || !value.getType().isa<BaseTensorType>()) if (!value || !isa<BaseTensorType>(value.getType()))
return failure(); return failure();
auto tensorType = cast<BaseTensorType>(value.getType()); auto tensorType = cast<BaseTensorType>(value.getType());
@ -2518,7 +2515,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold int -> float conversion. // Constant fold int -> float conversion.
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) { if (auto integerAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getA())) {
return FloatAttr::get( return FloatAttr::get(
mlir::Float64Type::get(getContext()), mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue())); static_cast<double>(integerAttr.getValue().getSExtValue()));
@ -2535,7 +2532,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion. // Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) { if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get( return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64), mlir::IntegerType::get(getContext(), 64),
static_cast<int64_t>(floatAttr.getValue().convertToDouble())); static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
@ -2549,7 +2546,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion. // Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) { if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get( return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64), mlir::IntegerType::get(getContext(), 64),
static_cast<long>(floatAttr.getValue().convertToDouble())); static_cast<long>(floatAttr.getValue().convertToDouble()));
@ -2695,9 +2692,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands, MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = properties.as<Properties *>() auto attr =
->getValue() dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
.dyn_cast_or_null<ElementsAttr>();
if (!attr) if (!attr)
return failure(); return failure();
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType()); RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@ -2723,10 +2719,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred, bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
TypeRange actual) { TypeRange actual) {
if (!actual[0].isa<BaseTensorType>()) if (!isa<BaseTensorType>(actual[0]))
return false; return false;
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(), return areSizesAndDtypesCompatible(cast<BaseTensorType>(inferred[0]),
actual[0].cast<BaseTensorType>()); cast<BaseTensorType>(actual[0]));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2737,9 +2733,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands, MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = properties.as<Properties *>() auto attr =
->getValue() dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
.dyn_cast_or_null<ElementsAttr>();
if (!attr) if (!attr)
return failure(); return failure();
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType()); RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@ -2760,8 +2755,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs, bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
mlir::TypeRange outputs) { mlir::TypeRange outputs) {
return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(), return areSizesAndDtypesCompatible(cast<BaseTensorType>(inputs[0]),
outputs[0].cast<BaseTensorType>()); cast<BaseTensorType>(outputs[0]));
} }
void TensorStaticInfoCastOp::getCanonicalizationPatterns( void TensorStaticInfoCastOp::getCanonicalizationPatterns(
@ -3072,7 +3067,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
if (!operandType) if (!operandType)
return nullptr; return nullptr;
if (operandType.hasDtype()) { if (operandType.hasDtype()) {
bool isFloatType = operandType.getDtype().isa<mlir::FloatType>(); bool isFloatType = isa<mlir::FloatType>(operandType.getDtype());
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType); return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
} }
// doesn't has dtype // doesn't has dtype
@ -3130,12 +3125,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
int64_t start; int64_t start;
int64_t end; int64_t end;
int64_t step; int64_t step;
if (op.getStart().getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getStart().getType())) {
start = 0; start = 0;
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { } else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
return failure(); return failure();
} }
if (op.getEnd().getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(op.getEnd().getType())) {
end = listElements.size(); end = listElements.size();
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { } else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
return failure(); return failure();
@ -3228,7 +3223,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// things. // things.
Value replacement = tupleConstruct.getElements()[i]; Value replacement = tupleConstruct.getElements()[i];
if (replacement.getType() != op.getType()) { if (replacement.getType() != op.getType()) {
if (op.getType().isa<BaseTensorType>()) { if (isa<BaseTensorType>(op.getType())) {
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>( replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
op.getLoc(), op.getType(), replacement); op.getLoc(), op.getType(), replacement);
} else { } else {
@ -3384,8 +3379,8 @@ using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
static OpFoldResult static OpFoldResult
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands, atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryIntOperatorFn f) { BinaryIntOperatorFn f) {
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>(); auto intLhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>(); auto intRhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
if (!intLhs || !intRhs) { if (!intLhs || !intRhs) {
return nullptr; return nullptr;
} }
@ -3711,7 +3706,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) { if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a + b; }); [](int64_t a, int64_t b) -> int64_t { return a + b; });
@ -3730,7 +3725,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) { if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a * b; }); [](int64_t a, int64_t b) -> int64_t { return a * b; });
@ -3749,7 +3744,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) { if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a - b; }); [](int64_t a, int64_t b) -> int64_t { return a - b; });
@ -3806,7 +3801,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) { if (!adaptor.getA()) {
return nullptr; return nullptr;
} }
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>(); auto floatValue = dyn_cast_or_null<FloatAttr>(adaptor.getA());
if (!floatValue) { if (!floatValue) {
return nullptr; return nullptr;
} }
@ -3834,7 +3829,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) { if (!adaptor.getA()) {
return nullptr; return nullptr;
} }
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>(); auto value = dyn_cast_or_null<FloatAttr>(adaptor.getA());
if (!value) { if (!value) {
return nullptr; return nullptr;
} }
@ -4487,8 +4482,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
if (getA() == getB()) if (getA() == getB())
return getA(); return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>(); auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>(); auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!lhs || !rhs) if (!lhs || !rhs)
return nullptr; return nullptr;
// Torch semantics are that !torch.int is 64-bit signed. // Torch semantics are that !torch.int is 64-bit signed.
@ -4556,8 +4551,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
if (getA() == getB()) if (getA() == getB())
return getA(); return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>(); auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>(); auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!lhs || !rhs) if (!lhs || !rhs)
return nullptr; return nullptr;
// Torch semantics are that !torch.int is 64-bit signed. // Torch semantics are that !torch.int is 64-bit signed.
@ -4644,8 +4639,8 @@ LogicalResult AtenNormScalarOp::verify() {
// Check if dtype is one of those supported by norm operation. // Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be // ComplexType will match any torch complex types, but each float must be
// checked individually. // checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type, if (!isa<mlir::ComplexType, mlir::Float16Type, mlir::Float32Type,
mlir::Float32Type, mlir::Float64Type>()) { mlir::Float64Type>(inTensorDtype)) {
return emitOpError( return emitOpError(
"expected a float or complex type for input tensor, but got ") "expected a float or complex type for input tensor, but got ")
<< inTensorDtype; << inTensorDtype;

View File

@ -190,8 +190,8 @@ static bool isValidTorchDtype(Type dtype) {
// Builtin floating point types. // Builtin floating point types.
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype)) if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
return true; return true;
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType, if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>()) Float8E4M3FNUZType, Float8E4M3B11FNUZType>(dtype))
return true; return true;
if (isa<Torch::StringType>(dtype)) if (isa<Torch::StringType>(dtype))
@ -228,9 +228,9 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
Type BaseTensorType::getWithSizesAndDtype( Type BaseTensorType::getWithSizesAndDtype(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const { std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
if (isa<NonValueTensorType>()) if (mlir::isa<NonValueTensorType>(*this))
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype); return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
if (isa<ValueTensorType>()) if (mlir::isa<ValueTensorType>(*this))
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype); return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
llvm_unreachable("not a BaseTensorType!"); llvm_unreachable("not a BaseTensorType!");
} }
@ -248,9 +248,9 @@ Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
} }
ValueTensorType BaseTensorType::getWithValueSemantics() const { ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>()) if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getWithValueSemantics(); return tensor.getWithValueSemantics();
if (auto tensor = dyn_cast<ValueTensorType>()) if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor; return tensor;
llvm_unreachable("not a BaseTensorType!"); llvm_unreachable("not a BaseTensorType!");
} }

View File

@ -110,7 +110,7 @@ public:
continue; continue;
auto it = typeBoundMap.find({call.getCallee(), operand.index()}); auto it = typeBoundMap.find({call.getCallee(), operand.index()});
if (it != typeBoundMap.end()) { if (it != typeBoundMap.end()) {
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) { if (auto valueTensorType = dyn_cast<ValueTensorType>(it->second)) {
newOperands.push_back(copyTensorToType( newOperands.push_back(copyTensorToType(
rewriter, call->getLoc(), valueTensorType, operand.value())); rewriter, call->getLoc(), valueTensorType, operand.value()));
continue; continue;
@ -215,11 +215,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
for (int i = 0, e = func.getNumArguments(); i != e; i++) { for (int i = 0, e = func.getNumArguments(); i != e; i++) {
if (func.getArgAttr(i, "torch.type_bound")) if (func.getArgAttr(i, "torch.type_bound"))
return false; return false;
if (func.getArgumentTypes()[i].isa<Torch::NoneType>()) if (isa<Torch::NoneType>(func.getArgumentTypes()[i]))
return false; return false;
} }
for (int i = 0, e = func.getNumResults(); i != e; i++) { for (int i = 0, e = func.getNumResults(); i != e; i++) {
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>()) if (isa<Torch::NoneType>(func.getFunctionType().getResults()[i]))
return false; return false;
} }
return true; return true;

View File

@ -38,7 +38,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
if (failed(resDtype)) if (failed(resDtype))
return false; return false;
return resDtype->isa<mlir::FloatType>(); return isa<mlir::FloatType>(*resDtype);
} }
// Helper function to compute the return type of the reduction function. // Helper function to compute the return type of the reduction function.
@ -99,19 +99,15 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim, Operation *op, Value input, Value dim,
bool keepDim) { bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim); Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType = BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()), rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
dim, keepDim)
.cast<BaseTensorType>();
if (!valueType) if (!valueType)
return nullptr; return nullptr;
BaseTensorType indexType = BaseTensorType indexType =
valueType cast<BaseTensorType>(valueType.getWithSizesAndDtype(
.getWithSizesAndDtype( !valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>() : llvm::ArrayRef(valueType.getSizes()),
: llvm::ArrayRef(valueType.getSizes()), IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
return rewriter return rewriter
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst) .create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.getValues(); .getValues();
@ -1059,7 +1055,7 @@ public:
LogicalResult matchAndRewrite(AtenEyeMOp op, LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
auto outType = op.getType().dyn_cast<BaseTensorType>(); auto outType = dyn_cast<BaseTensorType>(op.getType());
if (!outType) if (!outType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported"); op, "Only tensor types input are currently supported");
@ -1659,11 +1655,9 @@ public:
unsigned inputRank = *maybeInputRank; unsigned inputRank = *maybeInputRank;
if (!indicesTensorType.hasSizes()) if (!indicesTensorType.hasSizes())
return failure(); return failure();
BaseTensorType valueTensorType = BaseTensorType valueTensorType = cast<BaseTensorType>(
inputType inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), inputType.getOptionalDtype()));
inputType.getOptionalDtype())
.cast<BaseTensorType>();
// If the dim type is `NoneType` i.e. reduce along all the dimensions. // If the dim type is `NoneType` i.e. reduce along all the dimensions.
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
@ -1671,10 +1665,8 @@ public:
// happens on the 0th dimension. // happens on the 0th dimension.
if (isa<Torch::NoneType>(dim.getType())) { if (isa<Torch::NoneType>(dim.getType())) {
BaseTensorType flattenType = BaseTensorType flattenType =
inputType cast<BaseTensorType>(inputType.getWithSizesAndDtype(
.getWithSizesAndDtype({kUnknownSize}, {kUnknownSize}, inputType.getOptionalDtype()));
inputType.getOptionalDtype())
.cast<BaseTensorType>();
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)); dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>( Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1)); loc, rewriter.getI64IntegerAttr(inputRank - 1));
@ -3003,7 +2995,7 @@ public:
bool dimIsNone = false; bool dimIsNone = false;
int64_t dim; int64_t dim;
Value dimValue = op.getDim(); Value dimValue = op.getDim();
if (dimValue.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(dimValue.getType())) {
dimIsNone = true; dimIsNone = true;
dim = inputRank - 1; dim = inputRank - 1;
} else { } else {
@ -3887,10 +3879,9 @@ public:
gradOutputViewSizesInt[0] = kUnknownSize; gradOutputViewSizesInt[0] = kUnknownSize;
gradOutputViewSizesInt[1] = 1; gradOutputViewSizesInt[1] = 1;
BaseTensorType gradOutputTypeForView = BaseTensorType gradOutputTypeForView =
gradOutputTy cast<BaseTensorType>(gradOutputTy.getWithSizesAndDtype(
.getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), llvm::ArrayRef(gradOutputViewSizesInt),
gradOutputTy.getOptionalDtype()) gradOutputTy.getOptionalDtype()));
.cast<BaseTensorType>();
Value gradOutputView = rewriter.create<Torch::AtenViewOp>( Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
@ -3918,10 +3909,9 @@ public:
} }
BaseTensorType gradWeightTy = BaseTensorType gradWeightTy =
inputTransposedTy cast<BaseTensorType>(inputTransposedTy.getWithSizesAndDtype(
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), llvm::ArrayRef(gradWeightSizesInt),
inputTransposedTy.getOptionalDtype()) inputTransposedTy.getOptionalDtype()));
.cast<BaseTensorType>();
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero); Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
gradWeight = rewriter.create<Torch::AtenConvolutionOp>( gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
@ -3937,10 +3927,9 @@ public:
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
gradWeightSizesInt[i + 2] = weightSizes[i + 2]; gradWeightSizesInt[i + 2] = weightSizes[i + 2];
BaseTensorType gradWeightNarrowTy = BaseTensorType gradWeightNarrowTy =
gradWeightTy cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), llvm::ArrayRef(gradWeightSizesInt),
gradWeightTy.getOptionalDtype()) gradWeightTy.getOptionalDtype()));
.cast<BaseTensorType>();
Value dim = rewriter.create<ConstantIntOp>( Value dim = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i + 2)); loc, rewriter.getI64IntegerAttr(i + 2));
@ -3970,10 +3959,9 @@ public:
gradWeightViewShapeValue); gradWeightViewShapeValue);
BaseTensorType gradWeightTypeForView = BaseTensorType gradWeightTypeForView =
gradWeightTy cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), llvm::ArrayRef(gradWeightViewShapeInt),
gradWeightTy.getOptionalDtype()) gradWeightTy.getOptionalDtype()));
.cast<BaseTensorType>();
gradWeight = rewriter.create<Torch::AtenViewOp>( gradWeight = rewriter.create<Torch::AtenViewOp>(
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
@ -3986,10 +3974,9 @@ public:
gradWeightViewShapeInt[gradWeightDimsOrder[i]]); gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
} }
BaseTensorType gradWeightTypeForMoveDim = BaseTensorType gradWeightTypeForMoveDim =
gradWeightTy cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), llvm::ArrayRef(gradWeightMoveDimShape),
gradWeightTy.getOptionalDtype()) gradWeightTy.getOptionalDtype()));
.cast<BaseTensorType>();
gradWeight = rewriter.create<AtenMovedimIntOp>( gradWeight = rewriter.create<AtenMovedimIntOp>(
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
@ -4009,9 +3996,8 @@ public:
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>( Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, gradOutput, cstZero, cstOne); loc, transposedType, gradOutput, cstZero, cstOne);
// Convolve input with grad_output. // Convolve input with grad_output.
if (failed( if (failed(getTransposedType(cast<BaseTensorType>(op.getResultTypes()[1]),
getTransposedType(op.getResultTypes()[1].cast<BaseTensorType>(), 0, 1, transposedType)))
0, 1, transposedType)))
return failure(); return failure();
gradWeight = rewriter.create<Torch::AtenConvolutionOp>( gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, loc, transposedType, inputTransposed, gradOutputTransposed, cstNone,
@ -4063,7 +4049,7 @@ public:
// TODO: Handle integer type operands. // TODO: Handle integer type operands.
auto inputType = cast<BaseTensorType>(input.getType()); auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) { if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point dtype"); op, "unimplemented: non-floating point dtype");
} }
@ -4125,7 +4111,7 @@ public:
MLIRContext *context = op.getContext(); MLIRContext *context = op.getContext();
BaseTensorType inputType = cast<BaseTensorType>(input.getType()); BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() || if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()) ||
!isNoneOrFloatDtype(context, dtype)) { !isNoneOrFloatDtype(context, dtype)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only floating-point type is supported"); op, "only floating-point type is supported");
@ -4133,7 +4119,7 @@ public:
SmallVector<Value> dimListElements; SmallVector<Value> dimListElements;
if (!getListConstructElements(dimList, dimListElements) && if (!getListConstructElements(dimList, dimListElements) &&
!dimList.getType().isa<Torch::NoneType>()) { !isa<Torch::NoneType>(dimList.getType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected `dim` to be `None` or constructed from list construct"); op, "expected `dim` to be `None` or constructed from list construct");
} }
@ -4215,7 +4201,7 @@ public:
return success(); return success();
} }
BaseTensorType inputType = cast<BaseTensorType>(input.getType()); BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support floating type input for training mode"); op, "only support floating type input for training mode");
Value noneVal = rewriter.create<ConstantNoneOp>(loc); Value noneVal = rewriter.create<ConstantNoneOp>(loc);
@ -4243,7 +4229,7 @@ public:
Value input = op.getInput(); Value input = op.getInput();
Value prob = op.getP(); Value prob = op.getP();
bool train = false; bool train = false;
if (!op.getTrain().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getTrain().getType())) {
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) { if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "train must be a boolean constant or none"); op, "train must be a boolean constant or none");
@ -4263,7 +4249,7 @@ public:
return success(); return success();
} }
BaseTensorType inputType = cast<BaseTensorType>(input.getType()); BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) { if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support floating type input for training mode"); op, "only support floating type input for training mode");
} }
@ -4332,7 +4318,7 @@ public:
Value self = op.getSelf(); Value self = op.getSelf();
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType()); BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
if (!inputTensorTy.hasDtype() || if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(inputTensorTy.getDtype())) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Only aten.std support floating type"); "Only aten.std support floating type");
} }
@ -4388,7 +4374,7 @@ public:
Value self = op.getSelf(); Value self = op.getSelf();
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType()); BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
if (!inputTensorType.hasDtype() || if (!inputTensorType.hasDtype() ||
!inputTensorType.getDtype().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(inputTensorType.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "aten.std.dim expects input tensor of floating-point type"); op, "aten.std.dim expects input tensor of floating-point type");
} }
@ -4413,7 +4399,7 @@ public:
Value self = op.getSelf(); Value self = op.getSelf();
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType()); BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
if (!inputTensorType.hasDtype() || if (!inputTensorType.hasDtype() ||
!inputTensorType.getDtype().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(inputTensorType.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, op,
"aten.std.correction expects input tensor of floating-point type"); "aten.std.correction expects input tensor of floating-point type");
@ -4506,7 +4492,7 @@ public:
Value input = op.getSelf(); Value input = op.getSelf();
Type resultType = op.getType(); Type resultType = op.getType();
auto inputType = cast<BaseTensorType>(input.getType()); auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) { if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support floating-point type"); "only support floating-point type");
} }
@ -4547,7 +4533,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
op, "can't decompose bernoulli like ops without sizes or dtype"); op, "can't decompose bernoulli like ops without sizes or dtype");
} }
// The `prob` is expected to be a float type tensor. // The `prob` is expected to be a float type tensor.
if (!probType.getDtype().isa<mlir::FloatType>()) { if (!isa<mlir::FloatType>(probType.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "probabilities must be a float type tensor"); op, "probabilities must be a float type tensor");
} }
@ -4582,7 +4568,7 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.getSelf(); Value input = op.getSelf();
if (!op.getGenerator().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
@ -4640,7 +4626,7 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.getSelf(); Value input = op.getSelf();
Value prob = op.getP(); Value prob = op.getP();
if (!op.getGenerator().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
@ -4665,7 +4651,7 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExponentialOp op, LogicalResult matchAndRewrite(AtenExponentialOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
@ -4706,7 +4692,7 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op, LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getGenerator().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default " op, "The generator has to be None because only global default "
"generator is supported"); "generator is supported");
@ -4984,10 +4970,10 @@ class DecomposeAtenNativeLayerNormOp
Value weight = op.getWeight(); Value weight = op.getWeight();
Value bias = op.getBias(); Value bias = op.getBias();
if (!weight.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(weight.getType())) {
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight); out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
} }
if (!bias.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(bias.getType())) {
out = out =
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one); rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
} }
@ -5238,13 +5224,13 @@ class DecomposeAtenNativeGroupNormOp
loc, ListType::get(IntType::get(context)), viewShape); loc, ListType::get(IntType::get(context)), viewShape);
Value groupNormOutput = reshapedOutput; Value groupNormOutput = reshapedOutput;
if (!weight.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(weight.getType())) {
auto weightReshaped = rewriter.create<AtenViewOp>( auto weightReshaped = rewriter.create<AtenViewOp>(
loc, baseType, weight, /*shape=*/viewShapeSizeList); loc, baseType, weight, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenMulTensorOp>( groupNormOutput = rewriter.create<AtenMulTensorOp>(
loc, inputType, groupNormOutput, weightReshaped); loc, inputType, groupNormOutput, weightReshaped);
} }
if (!bias.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(bias.getType())) {
auto biasReshaped = rewriter.create<AtenViewOp>( auto biasReshaped = rewriter.create<AtenViewOp>(
loc, baseType, bias, /*shape=*/viewShapeSizeList); loc, baseType, bias, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenAddTensorOp>( groupNormOutput = rewriter.create<AtenAddTensorOp>(
@ -5297,8 +5283,8 @@ class DecomposeAtenNativeBatchNormOp
// In the inference mode, the `runningMean` and `runningVar` must not be // In the inference mode, the `runningMean` and `runningVar` must not be
// None. // None.
if (runningMean.getType().isa<Torch::NoneType>() || if (isa<Torch::NoneType>(runningMean.getType()) ||
runningVar.getType().isa<Torch::NoneType>()) isa<Torch::NoneType>(runningVar.getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "running stats must not be None in inference mode"); op, "running stats must not be None in inference mode");
@ -5354,7 +5340,7 @@ class DecomposeAtenNativeBatchNormOp
// 2. bias = bias.view(1, C, 1?, 1?, 1?) // 2. bias = bias.view(1, C, 1?, 1?, 1?)
// 3. output = normalizedInput * weight + bias // 3. output = normalizedInput * weight + bias
Value batchNormOutput = normalizedInput; Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(weight.getType())) {
// Rank of `weight` must be exactly 1. // Rank of `weight` must be exactly 1.
std::optional<unsigned> weightRank = getTensorRank(weight); std::optional<unsigned> weightRank = getTensorRank(weight);
if (!weightRank || *weightRank != 1) if (!weightRank || *weightRank != 1)
@ -5364,7 +5350,7 @@ class DecomposeAtenNativeBatchNormOp
batchNormOutput = rewriter.create<AtenMulTensorOp>( batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight); loc, batchNormOutput.getType(), batchNormOutput, weight);
} }
if (!bias.getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(bias.getType())) {
// Rank of `bias` must be exactly 1. // Rank of `bias` must be exactly 1.
std::optional<unsigned> biasRank = getTensorRank(bias); std::optional<unsigned> biasRank = getTensorRank(bias);
if (!biasRank || *biasRank != 1) if (!biasRank || *biasRank != 1)
@ -5444,7 +5430,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op, LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value dtype = op.getDtype(); Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType()); BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) { if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5518,7 +5504,7 @@ public:
return transposeWeight; return transposeWeight;
}; };
if (bias.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(bias.getType())) {
auto weightRank = weightType.getSizes().size(); auto weightRank = weightType.getSizes().size();
if (weightRank > 2 || weightRank <= 0) if (weightRank > 2 || weightRank <= 0)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5622,7 +5608,7 @@ public:
LogicalResult matchAndRewrite(AtenNewFullOp op, LogicalResult matchAndRewrite(AtenNewFullOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value dtype = op.getDtype(); Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType()); BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) { if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5718,7 +5704,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc()); Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dtype = op.getDtype(); Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType()); BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) { if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5743,9 +5729,9 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value value = op.getValue(); Value value = op.getValue();
if (value.getType().isa<Torch::OptionalType>()) if (isa<Torch::OptionalType>(value.getType()))
return rewriter.notifyMatchFailure(op, "optional type not supported"); return rewriter.notifyMatchFailure(op, "optional type not supported");
if (value.getType().isa<Torch::NoneType>()) if (isa<Torch::NoneType>(value.getType()))
value = rewriter.create<Torch::ConstantFloatOp>( value = rewriter.create<Torch::ConstantFloatOp>(
op.getLoc(), rewriter.getF64FloatAttr(0)); op.getLoc(), rewriter.getF64FloatAttr(0));
@ -5765,7 +5751,7 @@ public:
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op, LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
// TODO: Add support for pinMemory arg equal to `True`. // TODO: Add support for pinMemory arg equal to `True`.
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
bool pinMemory; bool pinMemory;
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5776,7 +5762,7 @@ public:
} }
// TODO: Add support for device arg other than cpu. // TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device; std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -5788,7 +5774,7 @@ public:
// TODO: Add support for non-strided layout. // TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0. // torch.layout is by default strided i.e. 0.
if (!op.getLayout().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout; int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -6254,7 +6240,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Type newOutputType = outputTensorType.getWithSizesAndDtype( Type newOutputType = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type()); outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() || if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) { !isa<mlir::FloatType>(inputTensorTy.getDtype())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "support floating-point type input only"); op, "support floating-point type input only");
} }
@ -6391,14 +6377,14 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
int64_t correctionValInt; int64_t correctionValInt;
double correctionValFloat = 1.0; double correctionValFloat = 1.0;
if (!op.getCorrection().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getCorrection().getType())) {
if (op.getCorrection().getType().isa<Torch::FloatType>()) { if (isa<Torch::FloatType>(op.getCorrection().getType())) {
if (!matchPattern(op.getCorrection(), if (!matchPattern(op.getCorrection(),
m_TorchConstantFloat(&correctionValFloat))) m_TorchConstantFloat(&correctionValFloat)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for " op, "Only support constant int or float correction value for "
"aten.var"); "aten.var");
} else if (op.getCorrection().getType().isa<Torch::IntType>()) { } else if (isa<Torch::IntType>(op.getCorrection().getType())) {
if (!matchPattern(op.getCorrection(), if (!matchPattern(op.getCorrection(),
m_TorchConstantInt(&correctionValInt))) m_TorchConstantInt(&correctionValInt)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -6525,11 +6511,9 @@ public:
if (!inputType.hasSizes()) if (!inputType.hasSizes())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Expected the input tensor to have sizes"); op, "Expected the input tensor to have sizes");
BaseTensorType subType = BaseTensorType subType = cast<BaseTensorType>(
inputType inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), resultType.getOptionalDtype()));
resultType.getOptionalDtype())
.cast<BaseTensorType>();
Value sub = Value sub =
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@ -6566,7 +6550,7 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
Value none = rewriter.create<Torch::ConstantNoneOp>(loc); Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value ord = op.getP(); Value ord = op.getP();
if (ord.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(ord.getType())) {
ord = rewriter.create<Torch::ConstantFloatOp>( ord = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0)); loc, rewriter.getF64FloatAttr(2.0));
} }
@ -6609,10 +6593,8 @@ public:
loc, rewriter.getF64FloatAttr((double)cstHigh)); loc, rewriter.getF64FloatAttr((double)cstHigh));
BaseTensorType floatResultType = BaseTensorType floatResultType =
resultTensorType cast<BaseTensorType>(resultTensorType.getWithSizesAndDtype(
.getWithSizesAndDtype(resultTensorType.getSizes(), resultTensorType.getSizes(), rewriter.getF32Type()));
rewriter.getF32Type())
.cast<BaseTensorType>();
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>( Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, floatResultType, op.getSize(), /*dtype=*/none, loc, floatResultType, op.getSize(), /*dtype=*/none,
/*layout=*/op.getLayout(), /*layout=*/op.getLayout(),
@ -6704,7 +6686,7 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimsVarOp op, LogicalResult matchAndRewrite(PrimsVarOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (!op.getOutputDtype().getType().isa<Torch::NoneType>()) if (!isa<Torch::NoneType>(op.getOutputDtype().getType()))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for prims::var op"); op, "Unimplemented non-None dtype for prims::var op");
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false); Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
@ -6816,7 +6798,7 @@ public:
LogicalResult matchAndRewrite(AtenRandnLikeOp op, LogicalResult matchAndRewrite(AtenRandnLikeOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
// Only `none`, `contiguous` and `preserve` memory_format is supported. // Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat; int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat))) m_TorchConstantInt(&memoryFormat)))
@ -6913,8 +6895,8 @@ public:
op.getDevice(), op.getPinMemory()); op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1) // calculate (end - start) / (steps - 1)
Value sub; Value sub;
if (op.getEnd().getType().isa<Torch::FloatType>() || if (isa<Torch::FloatType>(op.getEnd().getType()) ||
op.getStart().getType().isa<Torch::FloatType>()) { isa<Torch::FloatType>(op.getStart().getType())) {
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context), sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
op.getEnd(), op.getStart()); op.getEnd(), op.getStart());
} else { } else {
@ -6930,7 +6912,7 @@ public:
} }
// to dtype // to dtype
Value result; Value result;
if (!op.getDtype().getType().isa<Torch::NoneType>()) { if (!isa<Torch::NoneType>(op.getDtype().getType())) {
result = rewriter.create<AtenToDtypeOp>( result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none); /*copy=*/falseVal, /*memory_format=*/none);
@ -7344,11 +7326,8 @@ public:
auto selfType = cast<BaseTensorType>(self.getType()); auto selfType = cast<BaseTensorType>(self.getType());
auto indexType = cast<BaseTensorType>(index.getType()); auto indexType = cast<BaseTensorType>(index.getType());
BaseTensorType srcType = BaseTensorType srcType = cast<BaseTensorType>(selfType.getWithSizesAndDtype(
selfType indexType.getOptionalSizes(), selfType.getOptionalDtype()));
.getWithSizesAndDtype(indexType.getOptionalSizes(),
selfType.getOptionalDtype())
.cast<BaseTensorType>();
Value src = Value src =
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList); createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self, rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
@ -7372,7 +7351,7 @@ public:
"expected result type to have dtype"); "expected result type to have dtype");
} }
// TODO: support complex type in future. // TODO: support complex type in future.
if (outType.getDtype().isa<mlir::ComplexType>()) { if (isa<mlir::ComplexType>(outType.getDtype())) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"doesn't support complex type now"); "doesn't support complex type now");
} }
@ -7488,7 +7467,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
Location loc = op->getLoc(); Location loc = op->getLoc();
MLIRContext *context = op->getContext(); MLIRContext *context = op->getContext();
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes()) { if (!inputType.hasSizes()) {
return failure(); return failure();
} }
@ -7497,7 +7476,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
int64_t maxIndexRank = 0; int64_t maxIndexRank = 0;
for (auto index : oldIndices) { for (auto index : oldIndices) {
auto indexType = index.getType().dyn_cast<BaseTensorType>(); auto indexType = dyn_cast<BaseTensorType>(index.getType());
if (!indexType) // None index if (!indexType) // None index
continue; continue;
if (!indexType.hasSizes()) if (!indexType.hasSizes())
@ -7586,15 +7565,13 @@ public:
int64_t inputRank = inputSizes.size(); int64_t inputRank = inputSizes.size();
auto isTensor = [](Value v) { auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>(); return isa<Torch::BaseTensorType>(v.getType());
}; };
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
if (llvm::all_of(indices, isTensor)) { if (llvm::all_of(indices, isTensor)) {
// By default, we regard the first index type as the list element type. // By default, we regard the first index type as the list element type.
auto indexElemType = indices[0] auto indexElemType = cast<BaseTensorType>(indices[0].getType())
.getType()
.template cast<BaseTensorType>()
.getWithSizesAndDtype(std::nullopt, nullptr); .getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndices = rewriter.create<PrimListConstructOp>( auto newIndices = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices); loc, Torch::ListType::get(indexElemType), indices);
@ -7684,7 +7661,7 @@ public:
"failed to get elements of `indices`"); "failed to get elements of `indices`");
auto input = op.getSelf(); auto input = op.getSelf();
auto inputType = input.getType().template cast<BaseTensorType>(); auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes()) { if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only input with shape information is supported"); op, "only input with shape information is supported");
@ -7693,15 +7670,13 @@ public:
int64_t inputRank = inputSizes.size(); int64_t inputRank = inputSizes.size();
auto isTensor = [](Value v) { auto isTensor = [](Value v) {
return v.getType().isa<Torch::BaseTensorType>(); return isa<Torch::BaseTensorType>(v.getType());
}; };
// directly replace current op with aten.index_put.hacked_twin // directly replace current op with aten.index_put.hacked_twin
if (llvm::all_of(indices, isTensor)) { if (llvm::all_of(indices, isTensor)) {
// By default, we regard the first index type as the list element type. // By default, we regard the first index type as the list element type.
auto indexElemType = indices[0] auto indexElemType = cast<BaseTensorType>(indices[0].getType())
.getType()
.template cast<BaseTensorType>()
.getWithSizesAndDtype(std::nullopt, nullptr); .getWithSizesAndDtype(std::nullopt, nullptr);
auto newIndex = rewriter.create<PrimListConstructOp>( auto newIndex = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(indexElemType), indices); loc, Torch::ListType::get(indexElemType), indices);
@ -7831,7 +7806,7 @@ public:
// default ord value is 2 for vector_norm // default ord value is 2 for vector_norm
auto ord = op.getOrd(); auto ord = op.getOrd();
if (ord.getType().isa<Torch::NoneType>()) { if (isa<Torch::NoneType>(ord.getType())) {
ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2)); ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
} }
rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>( rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(

View File

@ -63,8 +63,8 @@ public:
}; };
static bool isTypeTriviallySafe(Type type) { static bool isTypeTriviallySafe(Type type) {
return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType, return isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>(); Torch::StringType, Torch::NoneType, Torch::ValueTensorType>(type);
} }
static bool isUseTreatedWithValueSemantics(OpOperand &use) { static bool isUseTreatedWithValueSemantics(OpOperand &use) {

View File

@ -36,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
static LogicalResult checkType(Operation *op, Type type, static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) { bool actuallyEmitDiagnostics) {
// Allow various scalar types that backends are expected to be able to handle. // Allow various scalar types that backends are expected to be able to handle.
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType, if (isa<Torch::IntType, Torch::FloatType, Torch::BoolType, Torch::DeviceType>(
Torch::DeviceType>()) type))
return success(); return success();
// Backends are not expected to support dynamic computations on these types, // Backends are not expected to support dynamic computations on these types,

View File

@ -187,7 +187,7 @@ public:
auto it = originalReturnTypes.find(i); auto it = originalReturnTypes.find(i);
if (it == originalReturnTypes.end()) if (it == originalReturnTypes.end())
continue; continue;
auto originalType = it->second.cast<NonValueTensorType>(); auto originalType = cast<NonValueTensorType>(it->second);
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(), Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(),
originalType, operand.get()); originalType, operand.get());
@ -350,7 +350,7 @@ public:
auto it = originalTypes.find(operand.get()); auto it = originalTypes.find(operand.get());
if (it == originalTypes.end()) if (it == originalTypes.end())
continue; continue;
auto originalType = it->second.cast<BaseTensorType>(); auto originalType = cast<BaseTensorType>(it->second);
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
Value newReturnValue = copyTensorToType(rewriter, op->getLoc(), Value newReturnValue = copyTensorToType(rewriter, op->getLoc(),
originalType, operand.get()); originalType, operand.get());

View File

@ -118,7 +118,7 @@ public:
if (auto optionalType = if (auto optionalType =
dyn_cast<OptionalType>(listType.getContainedType())) { dyn_cast<OptionalType>(listType.getContainedType())) {
if (!llvm::all_of(listConstruct.getElements(), [](Value val) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
return val.getType().isa<NonValueTensorType, Torch::NoneType>(); return isa<NonValueTensorType, Torch::NoneType>(val.getType());
})) { })) {
rewriter.cancelOpModification(op); rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(

View File

@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
if (name.starts_with("valsem.")) if (name.starts_with("valsem."))
name = name.drop_front(strlen("valsem.")); name = name.drop_front(strlen("valsem."));
if (isa<OperatorOp>(op)) if (isa<OperatorOp>(op))
name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue(); name = cast<StringAttr>(cast<OperatorOp>(op)->getAttr("name")).getValue();
std::string libFuncName = std::string libFuncName =
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName); auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
// to match the library function signature. // to match the library function signature.
if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) { if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType return isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(); containedType);
})) }))
return b.create<DerefineOp>(loc, desiredType, operand).getResult(); return b.create<DerefineOp>(loc, desiredType, operand).getResult();
} }

View File

@ -179,11 +179,10 @@ public:
"should have concrete Scalar Type."); "should have concrete Scalar Type.");
} }
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType()); Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType = auto impliedTypeFromInputType = cast<BaseTensorType>(
cast<BaseTensorType>(originalResultType) cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(originalResultType.getOptionalSizes(), .getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType) inputType));
.cast<BaseTensorType>();
op.getResult().setType(impliedTypeFromInputType); op.getResult().setType(impliedTypeFromInputType);
return success(); return success();

View File

@ -97,11 +97,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
} }
auto originalResultType = cast<BaseTensorType>(result.getType()); auto originalResultType = cast<BaseTensorType>(result.getType());
auto impliedTypesFromShape = auto impliedTypesFromShape = cast<BaseTensorType>(
cast<BaseTensorType>(originalResultType) cast<BaseTensorType>(originalResultType)
.getWithSizesAndDtype(ArrayRef(sizes), .getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype()) originalResultType.getOptionalDtype()));
.cast<BaseTensorType>();
return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape, return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
rewriter); rewriter);

View File

@ -74,7 +74,7 @@ LogicalResult FromBuiltinTensorOp::verify() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) { OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>(); auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -87,7 +87,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) { OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>(); auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -100,7 +100,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>(); auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -113,7 +113,7 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>(); auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -126,7 +126,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>(); auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -139,7 +139,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) { OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>(); auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
if (attr) { if (attr) {
return attr; return attr;
} else { } else {

View File

@ -91,7 +91,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
return std::nullopt; return std::nullopt;
// Other input type to be converted to i64 are handled by other // Other input type to be converted to i64 are handled by other
// materializers. // materializers.
if (!inputs[0].getType().isa<Torch::IntType>()) if (!isa<Torch::IntType>(inputs[0].getType()))
return std::nullopt; return std::nullopt;
assert(inputs.size() == 1); assert(inputs.size() == 1);
return builder.create<ToI64Op>(loc, inputs[0]).getResult(); return builder.create<ToI64Op>(loc, inputs[0]).getResult();
@ -145,7 +145,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
return std::nullopt; return std::nullopt;
// Other input type to be converted to i64 are handled by other // Other input type to be converted to i64 are handled by other
// materializers. // materializers.
if (!inputs[0].getType().isa<Torch::GeneratorType>()) if (!isa<Torch::GeneratorType>(inputs[0].getType()))
return std::nullopt; return std::nullopt;
assert(inputs.size() == 1); assert(inputs.size() == 1);
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult(); return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();

View File

@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
static bool isArgMemRefTypeValid(Type type) { static bool isArgMemRefTypeValid(Type type) {
if (auto memRefType = dyn_cast<MemRefType>(type)) { if (auto memRefType = dyn_cast<MemRefType>(type)) {
Type elemTy = memRefType.getElementType(); Type elemTy = memRefType.getElementType();
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) { if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) {
return true; return true;
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) { } else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
if (integerTy.isSignlessInteger(64)) if (integerTy.isSignlessInteger(64))
@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) {
if (integerTy.isSignlessInteger(1)) if (integerTy.isSignlessInteger(1))
return true; return true;
} else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) { } else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
return complexTy.getElementType().isa<Float32Type, Float64Type>(); return isa<Float32Type, Float64Type>(complexTy.getElementType());
} }
} }
return false; return false;