mirror of https://github.com/llvm/torch-mlir
[NFC] Change to *cast instead of .*cast variants (#3405)
Member casts have been deprecated. Changing over a bunch of the member cast calls to the global templated variants to remove deprecation warnings.pull/3407/head
parent
4e05e2cd1e
commit
afca88a058
|
@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
|
|||
|
||||
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
|
||||
|
||||
# TODO(#3299): migrate to from member x.cast<T>() to mlir::cast<T>(x).
|
||||
if(MSVC)
|
||||
add_compile_options(/wd4996)
|
||||
else()
|
||||
add_compile_options(-Wno-deprecated-declarations)
|
||||
endif()
|
||||
|
||||
macro(torch_mlir_enable_werror)
|
||||
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
|
||||
if(NOT MSVC)
|
||||
|
|
|
@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::copy_if(getInputOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<MemRefType>();
|
||||
return isa<MemRefType>(opOperand->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::copy_if(getInputOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||
return isa<RankedTensorType>(opOperand->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::copy_if(getOutputOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<MemRefType>();
|
||||
return isa<MemRefType>(opOperand->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::copy_if(getOutputOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||
return isa<RankedTensorType>(opOperand->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::transform(getOutputBufferOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperands) {
|
||||
return opOperands->get().getType().cast<MemRefType>();
|
||||
return cast<MemRefType>(opOperands->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
llvm::transform(getOutputTensorOperands(),
|
||||
std::back_inserter(result),
|
||||
[](OpOperand *opOperands) {
|
||||
return opOperands->get().getType().cast<RankedTensorType>();
|
||||
return cast<RankedTensorType>(opOperands->get().getType());
|
||||
});
|
||||
return result;
|
||||
}]
|
||||
|
@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
/*args=*/(ins "OpOperand *":$opOperand),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
||||
if (!isa<RankedTensorType>(opOperand->get().getType()))
|
||||
return false;
|
||||
if (opOperand->getOperandNumber() < $_op.getNumInputs())
|
||||
return true;
|
||||
|
@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
/*args=*/(ins "OpOperand *":$opOperand),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
||||
if (!isa<RankedTensorType>(opOperand->get().getType()))
|
||||
return false;
|
||||
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
|
||||
return true;
|
||||
|
@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
/*defaultImplementation=*/[{
|
||||
assert(opOperand->getOwner() == this->getOperation());
|
||||
if (auto shapedType =
|
||||
opOperand->get().getType().template dyn_cast<ShapedType>())
|
||||
dyn_cast<ShapedType>(opOperand->get().getType()))
|
||||
return shapedType.getRank();
|
||||
return 0;
|
||||
}]
|
||||
|
@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
/*defaultImplementation=*/[{
|
||||
assert(opOperand->getOwner() == this->getOperation());
|
||||
if (auto shapedType =
|
||||
opOperand->get().getType().template dyn_cast<ShapedType>())
|
||||
dyn_cast<ShapedType>(opOperand->get().getType()))
|
||||
return shapedType.getShape();
|
||||
return {};
|
||||
}]
|
||||
|
@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(opOperand->getOwner() == this->getOperation());
|
||||
return !opOperand->get().getType().template isa<ShapedType>();
|
||||
return !isa<ShapedType>(opOperand->get().getType());
|
||||
}]
|
||||
>,
|
||||
//===------------------------------------------------------------------===//
|
||||
|
@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
return this->getOperation()->getNumResults() == 0 &&
|
||||
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||
return isScalar(opOperand) ||
|
||||
opOperand->get().getType().template isa<MemRefType>();
|
||||
isa<MemRefType>(opOperand->get().getType());
|
||||
}) &&
|
||||
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<MemRefType>();
|
||||
return isa<MemRefType>(opOperand->get().getType());
|
||||
});
|
||||
}]
|
||||
>,
|
||||
|
@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
return
|
||||
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||
return isScalar(opOperand) ||
|
||||
opOperand->get().getType().template isa<RankedTensorType>();
|
||||
isa<RankedTensorType>(opOperand->get().getType());
|
||||
}) &&
|
||||
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||
return isa<RankedTensorType>(opOperand->get().getType());
|
||||
});
|
||||
}]
|
||||
>,
|
||||
|
@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
|
||||
private:
|
||||
void setOperandSegmentAt(unsigned idx, unsigned val) {
|
||||
auto attr = (*this)->getAttr("operand_segment_sizes")
|
||||
.cast<DenseIntElementsAttr>();
|
||||
auto attr = cast<DenseIntElementsAttr>((*this)->getAttr("operand_segment_sizes")
|
||||
);
|
||||
unsigned i = 0;
|
||||
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
||||
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
||||
|
|
|
@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
|
|||
return getOutputOperand(0)->get();
|
||||
}
|
||||
ShapedType getOperandType() {
|
||||
return input().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(input().getType());
|
||||
}
|
||||
int64_t getOperandRank() {
|
||||
return getOperandType().getRank();
|
||||
|
@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||
|
||||
int64_t getIndexDepth() {
|
||||
return getInputOperand(1)
|
||||
return cast<ShapedType>(getInputOperand(1)
|
||||
->get()
|
||||
.getType()
|
||||
.cast<ShapedType>()
|
||||
)
|
||||
.getShape()
|
||||
.back();
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
}
|
||||
|
||||
ShapedType getUpdateType() {
|
||||
return updates().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(updates().getType());
|
||||
}
|
||||
|
||||
Value indices() {
|
||||
|
@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
}
|
||||
|
||||
ShapedType getIndicesType() {
|
||||
return indices().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(indices().getType());
|
||||
}
|
||||
|
||||
Value original() {
|
||||
|
@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
|||
}
|
||||
|
||||
ShapedType getOriginalType() {
|
||||
return original().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(original().getType());
|
||||
}
|
||||
|
||||
int64_t getUpdateSliceRank() {
|
||||
return updates().getType().cast<ShapedType>().getRank() - 1;
|
||||
return cast<ShapedType>(updates().getType()).getRank() - 1;
|
||||
}
|
||||
|
||||
bool isScalarUpdate() {
|
||||
|
@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
|
|||
return getOutputs()[index];
|
||||
}
|
||||
ShapedType getOperandType(int index) {
|
||||
return operand(index).getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(operand(index).getType());
|
||||
}
|
||||
int64_t getOperandRank() {
|
||||
return getOperandType(0).getRank();
|
||||
|
@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
|||
return getOutputOperand(0)->get();
|
||||
}
|
||||
ShapedType getQueryType() {
|
||||
return getQuery().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(getQuery().getType());
|
||||
}
|
||||
ShapedType getKeyType() {
|
||||
return getKey().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(getKey().getType());
|
||||
}
|
||||
ShapedType getValueType() {
|
||||
return getValue().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(getValue().getType());
|
||||
}
|
||||
ShapedType getOutputType() {
|
||||
return getOutput().getType().cast<ShapedType>();
|
||||
return cast<ShapedType>(getOutput().getType());
|
||||
}
|
||||
int64_t getQueryRank() {
|
||||
return getQueryType().getRank();
|
||||
|
|
|
@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder {
|
|||
|
||||
bool match(Operation *op) {
|
||||
auto constOp = dyn_cast<Torch::OperatorOp>(op);
|
||||
if (!constOp || !constOp.getName().equals("onnx.Constant"))
|
||||
if (!constOp || !(constOp.getName() == "onnx.Constant"))
|
||||
return false;
|
||||
|
||||
if (DenseResourceElementsAttr attr =
|
||||
constOp->getAttr("torch.onnx.value")
|
||||
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
|
||||
dyn_cast_or_null<DenseResourceElementsAttr>(
|
||||
constOp->getAttr("torch.onnx.value"))) {
|
||||
// Bytes are stored in little endian order. Big endian support will
|
||||
// require swizzling.
|
||||
if (!Endian::little) {
|
||||
|
|
|
@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
|
|||
int64_t num;
|
||||
if (matchPattern(value, m_TorchConstantInt(&num)))
|
||||
bind_values.push_back(num);
|
||||
else if (value.getType().isa<Torch::NoneType>())
|
||||
else if (isa<Torch::NoneType>(value.getType()))
|
||||
bind_values.push_back(std::nullopt);
|
||||
else
|
||||
return false;
|
||||
|
|
|
@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
|||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
|
||||
Type getValueType() { return getType().cast<DictType>().getValueType(); }
|
||||
Type getKeyType() { return cast<DictType>(getType()).getKeyType(); }
|
||||
Type getValueType() { return cast<DictType>(getType()).getValueType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
|
|||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"operand is corresponding !torch.vtensor",
|
||||
"result", "operand",
|
||||
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
|
||||
"cast<NonValueTensorType>($_self).getWithValueSemantics()">,
|
||||
]> {
|
||||
let summary = "Create a !torch.tensor with the same contents as the operand";
|
||||
let description = [{
|
||||
|
@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
|||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"operand is corresponding !torch.tensor",
|
||||
"result", "operand",
|
||||
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
|
||||
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">,
|
||||
]> {
|
||||
let summary = "Create a !torch.vtensor with the same contents as the operand";
|
||||
let description = [{
|
||||
|
@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
|||
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
|
||||
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
|
||||
"value", "overwritten",
|
||||
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
|
||||
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">
|
||||
]> {
|
||||
let summary = "Ovewrite the contents of tensor with values from another.";
|
||||
let description = [{
|
||||
|
|
|
@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
|
|||
}
|
||||
|
||||
def AnyTorchTensorType : Type<
|
||||
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
|
||||
CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">,
|
||||
"Any Torch tensor type"
|
||||
>;
|
||||
|
||||
|
@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType:
|
|||
def AnyTorchOptionalGeneratorType:
|
||||
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
|
||||
|
||||
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
|
||||
def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">;
|
||||
class ListOf<list<Type> allowedTypes, string descr> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>,
|
||||
IsListTypePred,
|
||||
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
|
||||
"cast<::mlir::torch::Torch::ListType>($_self).getContainedType()",
|
||||
descr, "::mlir::torch::Torch::ListType">;
|
||||
|
||||
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||
|
|
|
@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchNnModule(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NnModuleType>();
|
||||
return isa<Torch::NnModuleType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
||||
|
@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchOptional(MlirType t) {
|
||||
return unwrap(t).isa<Torch::OptionalType>();
|
||||
return isa<Torch::OptionalType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
||||
|
@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchTuple(MlirType t) {
|
||||
return unwrap(t).isa<Torch::TupleType>();
|
||||
return isa<Torch::TupleType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||
|
@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchUnion(MlirType t) {
|
||||
return unwrap(t).isa<Torch::UnionType>();
|
||||
return isa<Torch::UnionType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
||||
|
@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchList(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ListType>();
|
||||
return isa<Torch::ListType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
||||
|
@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchDevice(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DeviceType>();
|
||||
return isa<Torch::DeviceType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
||||
|
@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchGenerator(MlirType t) {
|
||||
return unwrap(t).isa<Torch::GeneratorType>();
|
||||
return isa<Torch::GeneratorType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
|
||||
|
@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchBool(MlirType t) {
|
||||
return unwrap(t).isa<Torch::BoolType>();
|
||||
return isa<Torch::BoolType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
||||
|
@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchInt(MlirType t) {
|
||||
return unwrap(t).isa<Torch::IntType>();
|
||||
return isa<Torch::IntType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
||||
|
@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchFloat(MlirType t) {
|
||||
return unwrap(t).isa<Torch::FloatType>();
|
||||
return isa<Torch::FloatType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
||||
|
@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
|
||||
return unwrap(t).isa<Torch::LinearParamsType>();
|
||||
return isa<Torch::LinearParamsType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
||||
|
@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchQInt8(MlirType t) {
|
||||
return unwrap(t).isa<Torch::QInt8Type>();
|
||||
return isa<Torch::QInt8Type>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
||||
|
@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchQUInt8(MlirType t) {
|
||||
return unwrap(t).isa<Torch::QUInt8Type>();
|
||||
return isa<Torch::QUInt8Type>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
|
||||
|
@ -274,7 +274,7 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
||||
return isa<Torch::NonValueTensorType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
||||
|
@ -341,7 +341,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
||||
return isa<Torch::ValueTensorType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
||||
|
@ -408,7 +408,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchNone(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NoneType>();
|
||||
return isa<Torch::NoneType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
||||
|
@ -424,7 +424,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchString(MlirType t) {
|
||||
return unwrap(t).isa<Torch::StringType>();
|
||||
return isa<Torch::StringType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
||||
|
@ -440,7 +440,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchAny(MlirType t) {
|
||||
return unwrap(t).isa<Torch::AnyType>();
|
||||
return isa<Torch::AnyType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
||||
|
@ -456,7 +456,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchNumber(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NumberType>();
|
||||
return isa<Torch::NumberType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
||||
|
@ -472,7 +472,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool torchMlirTypeIsATorchDict(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DictType>();
|
||||
return isa<Torch::DictType>(unwrap(t));
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||
|
|
|
@ -546,12 +546,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value shuffledPaddingList =
|
||||
createConstantIntList(binder, rewriter, padding);
|
||||
Value zero;
|
||||
if (resultTypeOut.getDtype().isa<FloatType>()) {
|
||||
if (isa<FloatType>(resultTypeOut.getDtype())) {
|
||||
zero = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(
|
||||
std::numeric_limits<double>::lowest()));
|
||||
} else if (resultTypeOut.getDtype().isa<IntegerType>()) {
|
||||
} else if (isa<IntegerType>(resultTypeOut.getDtype())) {
|
||||
zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
std::numeric_limits<int64_t>::lowest()));
|
||||
|
@ -1295,7 +1295,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
|
||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input type having sizes");
|
||||
|
@ -1509,10 +1509,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
if (!constantValue) {
|
||||
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||
if (dataTensorType.getDtype().isa<IntegerType>())
|
||||
if (isa<IntegerType>(dataTensorType.getDtype()))
|
||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
if (dataTensorType.getDtype().isa<FloatType>())
|
||||
if (isa<FloatType>(dataTensorType.getDtype()))
|
||||
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(0.0f));
|
||||
|
||||
|
|
|
@ -1023,9 +1023,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value constFalse =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
auto size = data.getType()
|
||||
.dyn_cast<Torch::ValueTensorType>()
|
||||
.getOptionalSizes();
|
||||
auto size =
|
||||
dyn_cast<Torch::ValueTensorType>(data.getType()).getOptionalSizes();
|
||||
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||
size, rewriter.getF64Type());
|
||||
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
|
@ -2906,8 +2905,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
scalesValueList = noneVal;
|
||||
sizesValueList = getValueList(sizeOperand);
|
||||
}
|
||||
if (scalesValueList.getType().isa<Torch::NoneType>() &&
|
||||
sizesValueList.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
|
||||
isa<Torch::NoneType>(sizesValueList.getType())) {
|
||||
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
|
||||
}
|
||||
rewriter
|
||||
|
|
|
@ -1868,9 +1868,8 @@ public:
|
|||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.getSelf();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
|
@ -2107,9 +2106,8 @@ public:
|
|||
|
||||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
|
@ -2343,9 +2341,8 @@ public:
|
|||
op, "diagonal dimensions cannot be identical");
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
RankedTensorType outputType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType outputType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value dim1Size, dim2Size;
|
||||
|
@ -2581,9 +2578,8 @@ public:
|
|||
})
|
||||
.getResult(0);
|
||||
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
|
||||
return success();
|
||||
|
@ -2608,9 +2604,8 @@ public:
|
|||
return failure();
|
||||
// Conversion is completed specified by information in the sparse tensor
|
||||
// type. Thus, we can rewrite all legalizedNames to the same construct.
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
|
||||
op, resultType, adaptor.getOperands()[0]);
|
||||
return success();
|
||||
|
|
|
@ -845,7 +845,7 @@ public:
|
|||
outputSizeIntValues = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||
|
||||
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
|
||||
// Convert float values to int values.
|
||||
// int_value = (int64_t)ceil(float_value)
|
||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH());
|
||||
|
@ -858,7 +858,7 @@ public:
|
|||
scaleFactorsInt.push_back(scaleFactorVal);
|
||||
}
|
||||
|
||||
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
|
||||
// Convert float values to int values.
|
||||
// int_value = (int64_t)ceil(float_value)
|
||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW());
|
||||
|
@ -1006,7 +1006,7 @@ public:
|
|||
unsigned hDimOffset = 2;
|
||||
|
||||
SmallVector<Value, 2> scaleFactorsFloatValues;
|
||||
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
|
||||
scaleFactorsFloatValues.push_back(adaptor.getScalesH());
|
||||
} else {
|
||||
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
||||
|
@ -1019,7 +1019,7 @@ public:
|
|||
scaleFactorsFloatValues.push_back(scaleFactorVal);
|
||||
}
|
||||
|
||||
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
|
||||
scaleFactorsFloatValues.push_back(adaptor.getScalesW());
|
||||
} else {
|
||||
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
||||
|
|
|
@ -41,7 +41,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
|||
return;
|
||||
int64_t minSI = -(1 << (numBits - 1));
|
||||
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
|
||||
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
|
||||
loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
|
||||
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
||||
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
||||
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
|
@ -1057,10 +1057,10 @@ public:
|
|||
loc, getAsOpFoldResult(outDims), accumulatorDType);
|
||||
|
||||
Value outputTensor;
|
||||
if (accumulatorDType != resultDTy && !bias.getType().isa<Torch::NoneType>())
|
||||
if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
|
||||
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
|
||||
accumulatorDType);
|
||||
if (bias.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(bias.getType())) {
|
||||
Value c0;
|
||||
if (isa<mlir::FloatType>(accumulatorDType)) {
|
||||
c0 = rewriter.create<arith::ConstantOp>(
|
||||
|
|
|
@ -409,10 +409,8 @@ public:
|
|||
Value self = adaptor.getSelf();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
Type elementType = selfType.getElementType();
|
||||
RankedTensorType indicesRankedTensorType =
|
||||
getTypeConverter()
|
||||
->convertType(op->getResult(1).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(1).getType()));
|
||||
|
||||
// TODO: Add support for 3D inputs.
|
||||
if (selfType.getRank() == 3)
|
||||
|
@ -717,10 +715,10 @@ public:
|
|||
|
||||
Location loc = op->getLoc();
|
||||
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
||||
outputType = typeConverter->convertType(op.getResult0().getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auxTensorType = typeConverter->convertType(op.getResult1().getType())
|
||||
.template cast<RankedTensorType>();
|
||||
outputType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op.getResult0().getType()));
|
||||
auxTensorType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op.getResult1().getType()));
|
||||
Type auxTensorElementType = auxTensorType.getElementType();
|
||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
|
@ -799,8 +797,8 @@ public:
|
|||
|
||||
Location loc = op->getLoc();
|
||||
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
||||
outputType = typeConverter->convertType(op.getResult().getType())
|
||||
.template cast<RankedTensorType>();
|
||||
outputType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op.getResult().getType()));
|
||||
buffVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, elementType, rewriter.getFloatAttr(elementType, 0));
|
||||
auxTensor = rewriter.create<tensor::EmptyOp>(
|
||||
|
|
|
@ -42,9 +42,8 @@ public:
|
|||
|
||||
if (train)
|
||||
return failure();
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
||||
adaptor.getInput());
|
||||
return success();
|
||||
|
@ -60,8 +59,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc,
|
|||
Value result =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type()));
|
||||
for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) {
|
||||
assert(index.getType().isa<mlir::IntegerType>() &&
|
||||
stride.getType().isa<mlir::IntegerType>() &&
|
||||
assert(isa<mlir::IntegerType>(index.getType()) &&
|
||||
isa<mlir::IntegerType>(stride.getType()) &&
|
||||
"Input arrays to `toLinearIndex` must only contain values of type "
|
||||
"`mlir::IntegerType`");
|
||||
Value mul = b.create<arith::MulIOp>(loc, result, stride);
|
||||
|
@ -129,7 +128,7 @@ public:
|
|||
if (!isa<mlir::FloatType>(elemTy))
|
||||
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
||||
|
||||
if (!generator.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(generator.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
@ -180,7 +179,7 @@ public:
|
|||
b.create<arith::MulFOp>(loc, updateFloat, scale);
|
||||
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
|
||||
Value truncRes = res;
|
||||
if (elemTy.isa<Float16Type, Float32Type>())
|
||||
if (isa<Float16Type, Float32Type>(elemTy))
|
||||
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
|
||||
b.create<linalg::YieldOp>(loc, truncRes);
|
||||
})
|
||||
|
|
|
@ -86,11 +86,8 @@ public:
|
|||
bool isUnsigned = false;
|
||||
if (!isa<mlir::FloatType>(inElementType)) {
|
||||
if (isa<mlir::IntegerType>(inElementType)) {
|
||||
auto integerTy = op.getSelf()
|
||||
.getType()
|
||||
.template cast<BaseTensorType>()
|
||||
.getDtype()
|
||||
.template dyn_cast<mlir::IntegerType>();
|
||||
auto integerTy = dyn_cast<mlir::IntegerType>(
|
||||
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
|
||||
isUnsigned = integerTy.isUnsigned();
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -280,7 +277,7 @@ public:
|
|||
|
||||
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
|
||||
Type resultElementType) {
|
||||
if (elem.getType().isa<mlir::ComplexType>()) {
|
||||
if (isa<mlir::ComplexType>(elem.getType())) {
|
||||
return b.create<complex::AbsOp>(loc, elem);
|
||||
}
|
||||
|
||||
|
@ -376,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
if (isa<mlir::FloatType>(resultElementType))
|
||||
return b.create<arith::MaximumFOp>(loc, self, result);
|
||||
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||
IntegerType intType = max.getSelf()
|
||||
.getType()
|
||||
.cast<BaseTensorType>()
|
||||
.getDtype()
|
||||
.dyn_cast<mlir::IntegerType>();
|
||||
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
||||
cast<BaseTensorType>(max.getSelf().getType()).getDtype());
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::MaxUIOp>(loc, self, result);
|
||||
if (intType.isSigned())
|
||||
|
@ -393,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
if (isa<mlir::FloatType>(resultElementType))
|
||||
return b.create<arith::MinimumFOp>(loc, self, result);
|
||||
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||
IntegerType intType = min.getSelf()
|
||||
.getType()
|
||||
.cast<BaseTensorType>()
|
||||
.getDtype()
|
||||
.dyn_cast<mlir::IntegerType>();
|
||||
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
||||
cast<BaseTensorType>(min.getSelf().getType()).getDtype());
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::MinUIOp>(loc, self, result);
|
||||
if (intType.isSigned())
|
||||
|
@ -657,9 +648,8 @@ public:
|
|||
return opInfo;
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elemType = resultType.getElementType();
|
||||
LogicalResult elemTypeCheck =
|
||||
validateReductionElementType(op, elemType, rewriter);
|
||||
|
|
|
@ -179,15 +179,13 @@ public:
|
|||
|
||||
for (auto i : {TOP, VCENTER, BOTTOM}) {
|
||||
for (auto j : {LEFT, HCENTER, RIGHT}) {
|
||||
auto constVtile{
|
||||
auto constVtile{dyn_cast_or_null<mlir::IntegerAttr>(
|
||||
mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp())
|
||||
.getValue()
|
||||
.dyn_cast_or_null<mlir::IntegerAttr>()};
|
||||
.getValue())};
|
||||
|
||||
auto constHtile{
|
||||
auto constHtile{dyn_cast_or_null<mlir::IntegerAttr>(
|
||||
mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp())
|
||||
.getValue()
|
||||
.dyn_cast_or_null<mlir::IntegerAttr>()};
|
||||
.getValue())};
|
||||
auto vSize = constVtile.getInt();
|
||||
auto hSize = constHtile.getInt();
|
||||
|
||||
|
@ -369,8 +367,8 @@ public:
|
|||
for (auto size : resultSize)
|
||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
Type resultElementType;
|
||||
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
resultElementType = resultType.getElementType();
|
||||
|
@ -426,7 +424,7 @@ public:
|
|||
op, "unimplemented: pin_memory must be either None or false");
|
||||
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)))
|
||||
|
@ -441,7 +439,7 @@ public:
|
|||
}
|
||||
|
||||
// TODO: Add support for device arg other than cpu.
|
||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||
std::string device;
|
||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -453,7 +451,7 @@ public:
|
|||
|
||||
// TODO: Add support for non-strided layout.
|
||||
// torch.layout is by default strided i.e. 0.
|
||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -478,7 +476,7 @@ public:
|
|||
auto resultType =
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
Type resultElementType;
|
||||
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
resultElementType = getDefaultDtypeForTorchScalar(
|
||||
Torch::FloatType::get(op->getContext()));
|
||||
} else {
|
||||
|
@ -527,7 +525,7 @@ public:
|
|||
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -536,9 +534,8 @@ public:
|
|||
|
||||
Location loc = op.getLoc();
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
Type dtype = resultType.getElementType();
|
||||
Value start =
|
||||
convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
|
||||
|
|
|
@ -138,17 +138,16 @@ public:
|
|||
requires_grad = tensorFloatOp.getRequiresGrad();
|
||||
}
|
||||
// TODO: Dtype conversion.
|
||||
if (!dtype.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(dtype.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
|
||||
|
||||
// TODO: Device information.
|
||||
if (!device.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(device.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None device information");
|
||||
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type outElementType = resultType.getElementType();
|
||||
Value elemValProm =
|
||||
convertScalarToDtype(rewriter, loc, elemVal, outElementType);
|
||||
|
@ -171,9 +170,8 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type outElementType = resultType.getElementType();
|
||||
Value elemVal = adaptor.getA();
|
||||
Value elemValProm =
|
||||
|
|
|
@ -422,7 +422,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
||||
int64_t memoryFormat;
|
||||
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(clone.getMemoryFormat().getType()) &&
|
||||
(!matchPattern(clone.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
|
@ -434,24 +434,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return payloadArgs[0];
|
||||
}
|
||||
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
|
||||
if (bitwiseAndTensor.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(bitwiseAndTensor.getType()).getDtype())) {
|
||||
bitwiseAndTensor.emitError(
|
||||
"Bitwise_And does not support floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type dtype = converter->convertType(bitwiseAndTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseAndTensor.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(bitwiseAndScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseAndScalar.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::IntegerType>(dtype)) {
|
||||
bitwiseAndScalar.emitError(
|
||||
|
@ -469,32 +467,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::AndIOp>(loc, self, other);
|
||||
}
|
||||
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
|
||||
if (bitwiseOrTensor.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(bitwiseOrTensor.getType()).getDtype())) {
|
||||
bitwiseOrTensor.emitError(
|
||||
"Bitwise_Or does not support floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type dtype = converter->convertType(bitwiseOrTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseOrTensor.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<arith::OrIOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) {
|
||||
if (bitwiseXorTensor.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(bitwiseXorTensor.getType()).getDtype())) {
|
||||
bitwiseXorTensor.emitError(
|
||||
"Bitwise_Xor does not support floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type dtype = converter->convertType(bitwiseXorTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseXorTensor.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
|
@ -502,8 +496,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto bitwiseRightShiftTensor =
|
||||
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
|
||||
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseRightShiftTensor.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::IntegerType>(dtype)) {
|
||||
bitwiseRightShiftTensor.emitError(
|
||||
|
@ -516,8 +510,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto bitwiseLeftShiftTensor =
|
||||
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
|
||||
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type dtype = cast<RankedTensorType>(
|
||||
converter->convertType(bitwiseLeftShiftTensor.getType()))
|
||||
.getElementType();
|
||||
if (!isa<mlir::IntegerType>(dtype)) {
|
||||
bitwiseLeftShiftTensor.emitError(
|
||||
|
@ -557,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return createEqual(b, loc, floatDtype, self, zero);
|
||||
}
|
||||
if (isa<AtenAbsOp>(op)) {
|
||||
if (payloadArgs[0].getType().isa<IntegerType>())
|
||||
if (isa<IntegerType>(payloadArgs[0].getType()))
|
||||
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
|
||||
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
|
||||
}
|
||||
|
@ -653,20 +647,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
|
||||
}
|
||||
if (auto round = dyn_cast<AtenRoundOp>(op)) {
|
||||
if (!round.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(round.getType()).getDtype())) {
|
||||
round.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
|
||||
}
|
||||
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
|
||||
if (!prelu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(prelu.getType()).getDtype())) {
|
||||
prelu.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -685,10 +675,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
|
||||
}
|
||||
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
|
||||
if (!gelu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(gelu.getType()).getDtype())) {
|
||||
gelu.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -732,10 +720,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
|
||||
if (!geluBackward.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(geluBackward.getType()).getDtype())) {
|
||||
geluBackward.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -770,10 +756,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
|
||||
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
|
||||
if (!hardtanhBackward.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(hardtanhBackward.getType()).getDtype())) {
|
||||
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -967,10 +951,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
||||
if (!pow.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(pow.getType()).getDtype())) {
|
||||
pow.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1047,10 +1029,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
|
||||
if (!lerp.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(
|
||||
cast<ValueTensorType>(lerp.getType()).getDtype())) {
|
||||
lerp.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1064,8 +1044,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
|
||||
Type elemTy = converter->convertType(minimum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type elemTy =
|
||||
cast<RankedTensorType>(converter->convertType(minimum.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
|
||||
|
@ -1074,8 +1054,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
|
||||
Type elemTy = converter->convertType(maximum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type elemTy =
|
||||
cast<RankedTensorType>(converter->convertType(maximum.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
|
||||
|
@ -1086,8 +1066,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
AtenClampOp::Adaptor adaptor(operands);
|
||||
auto min = adaptor.getMin();
|
||||
auto max = adaptor.getMax();
|
||||
if (min.getType().isa<Torch::OptionalType>() ||
|
||||
max.getType().isa<Torch::OptionalType>()) {
|
||||
if (isa<Torch::OptionalType>(min.getType()) ||
|
||||
isa<Torch::OptionalType>(max.getType())) {
|
||||
clamp.emitError("unimplemented: runtime optional type");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1125,9 +1105,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
};
|
||||
|
||||
auto result = payloadArgs[0];
|
||||
if (!min.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(min.getType()))
|
||||
result = cmpSelect(result, min, /*getMax=*/false);
|
||||
if (!max.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(max.getType()))
|
||||
result = cmpSelect(result, max, /*getMax=*/true);
|
||||
return result;
|
||||
}
|
||||
|
@ -1135,8 +1115,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
AtenClampTensorOp::Adaptor adaptor(operands);
|
||||
auto min = adaptor.getMin();
|
||||
auto max = adaptor.getMax();
|
||||
if (min.getType().isa<Torch::OptionalType>() ||
|
||||
max.getType().isa<Torch::OptionalType>()) {
|
||||
if (isa<Torch::OptionalType>(min.getType()) ||
|
||||
isa<Torch::OptionalType>(max.getType())) {
|
||||
clampTensor.emitError("unimplemented: runtime optional type");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1145,7 +1125,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
.getElementType();
|
||||
bool isMinNone = true;
|
||||
auto result = payloadArgs[0];
|
||||
if (!min.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(min.getType())) {
|
||||
isMinNone = false;
|
||||
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value pred;
|
||||
|
@ -1163,7 +1143,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
|
||||
}
|
||||
if (!max.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(max.getType())) {
|
||||
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
|
||||
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||
Value pred;
|
||||
|
@ -1252,8 +1232,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::DivFOp>(loc, self, other);
|
||||
}
|
||||
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
|
||||
Type newResultType = converter->convertType(remScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type newResultType =
|
||||
cast<RankedTensorType>(converter->convertType(remScalar.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
||||
|
@ -1272,8 +1252,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
|
||||
Type newResultType = converter->convertType(remTensor.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type newResultType =
|
||||
cast<RankedTensorType>(converter->convertType(remTensor.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
||||
|
@ -1292,8 +1272,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return result;
|
||||
}
|
||||
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
|
||||
Type newResultType = converter->convertType(fmod.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type newResultType =
|
||||
cast<RankedTensorType>(converter->convertType(fmod.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
||||
|
@ -1420,8 +1400,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
|
||||
Type elementType = converter->convertType(bitwiseNot.getType())
|
||||
.cast<RankedTensorType>()
|
||||
Type elementType =
|
||||
cast<RankedTensorType>(converter->convertType(bitwiseNot.getType()))
|
||||
.getElementType();
|
||||
if (isa<mlir::FloatType>(elementType)) {
|
||||
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
|
||||
|
@ -1607,10 +1587,9 @@ public:
|
|||
|
||||
Location loc = op->getLoc();
|
||||
auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
|
||||
operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
bool hadErrorCreatingPayload = false;
|
||||
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, tensorOperands, resultType.getElementType(),
|
||||
|
@ -1657,7 +1636,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
|
||||
// TODO: Incorporate the weight argument.
|
||||
if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
|
||||
if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented, the weight operand is not incorporated.");
|
||||
|
||||
|
@ -1672,9 +1651,8 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input and target to be rank <= 2");
|
||||
}
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
|
||||
Value zeroVal = rewriter.create<arith::ConstantOp>(
|
||||
|
@ -1948,7 +1926,7 @@ public:
|
|||
Value input = adaptor.getSelf();
|
||||
Value target = adaptor.getTarget();
|
||||
Value weight = adaptor.getWeight();
|
||||
bool weightIsNone = op.getWeight().getType().isa<Torch::NoneType>();
|
||||
bool weightIsNone = isa<Torch::NoneType>(op.getWeight().getType());
|
||||
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
|
||||
Value totalWeight = adaptor.getTotalWeight();
|
||||
|
||||
|
@ -2069,9 +2047,8 @@ public:
|
|||
})
|
||||
->getResult(0);
|
||||
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
|
||||
return success();
|
||||
}
|
||||
|
@ -2214,9 +2191,8 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
||||
adaptor.getOperand());
|
||||
return success();
|
||||
|
@ -2243,7 +2219,7 @@ public:
|
|||
if (succeeded(checkNotNone(rewriter, op, eps)))
|
||||
handleEps = true;
|
||||
|
||||
if (handleEps && !eps.getType().isa<mlir::FloatType>()) {
|
||||
if (handleEps && !isa<mlir::FloatType>(eps.getType())) {
|
||||
op.emitError("Logit does not support non-floating point type");
|
||||
return failure();
|
||||
}
|
||||
|
@ -2317,9 +2293,8 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
||||
adaptor.getSelf());
|
||||
return success();
|
||||
|
@ -2362,8 +2337,8 @@ public:
|
|||
zeropoint = converter->materializeTargetConversion(
|
||||
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
|
||||
|
||||
auto resultType = converter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
converter->convertType(op->getResult(0).getType()));
|
||||
|
||||
llvm::SmallVector<Value> dynSizes;
|
||||
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
|
||||
|
@ -2553,9 +2528,8 @@ public:
|
|||
return res;
|
||||
};
|
||||
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult().getType()));
|
||||
SmallVector<Value> resultSize{};
|
||||
if (resultType.isDynamicDim(0))
|
||||
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
|
||||
|
@ -2675,7 +2649,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
|
|||
SmallVector<Value> scaleValues,
|
||||
std::string coordStr) {
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
|
||||
SmallVector<Value> indices;
|
||||
|
@ -2725,7 +2699,7 @@ static Value BilinearInterpolate(OpBuilder &b,
|
|||
SmallVector<Value> scaleValues,
|
||||
std::string coordStr) {
|
||||
unsigned dimOffset = 2;
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
|
||||
Value cstOneEps =
|
||||
|
@ -2877,7 +2851,7 @@ public:
|
|||
|
||||
Location loc = op->getLoc();
|
||||
Value input = adaptor.getInput();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputRank = inputType.getRank();
|
||||
if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -2893,7 +2867,7 @@ public:
|
|||
loc, rewriter.getIntegerType(64), inputSize));
|
||||
}
|
||||
|
||||
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
|
||||
bool recompScale;
|
||||
if (!matchPattern(op.getRecomputeScaleFactor(),
|
||||
m_TorchConstantBool(&recompScale)))
|
||||
|
|
|
@ -52,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor(
|
|||
Value torch_to_linalg::getZeroPaddedTensor(
|
||||
Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<int64_t> &paddingInts) {
|
||||
assert(input.getType().isa<RankedTensorType>() &&
|
||||
assert(isa<RankedTensorType>(input.getType()) &&
|
||||
"input must be RankedTensorType");
|
||||
Location loc = op->getLoc();
|
||||
Value c0 = b.create<arith::ConstantOp>(
|
||||
|
@ -67,7 +67,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
|
|||
Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
|
||||
int unpaddedDims, Value pad) {
|
||||
assert(input.getType().isa<RankedTensorType>() &&
|
||||
assert(isa<RankedTensorType>(input.getType()) &&
|
||||
"input must be RankedTensorType");
|
||||
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
Location loc = op->getLoc();
|
||||
|
|
|
@ -252,7 +252,7 @@ public:
|
|||
// "block" arguments
|
||||
for (const auto &barg : enumerate(op.getRegion().front().getArguments())) {
|
||||
Value to = block->getArgument(barg.index());
|
||||
if (to.getType().isa<mlir::IndexType>())
|
||||
if (isa<mlir::IndexType>(to.getType()))
|
||||
to =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
|
||||
Type targetType = to.getType();
|
||||
|
|
|
@ -146,9 +146,9 @@ public:
|
|||
if (!selfType) {
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
||||
return success();
|
||||
|
@ -203,9 +203,9 @@ public:
|
|||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto resultTy = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
if (isa<mlir::FloatType>(resultTy.getElementType())) {
|
||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
||||
|
@ -231,9 +231,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
auto outType = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
if (!outType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
@ -321,9 +321,9 @@ public:
|
|||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("only Tensor types supported");
|
||||
|
||||
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outTy = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
|
@ -354,9 +354,9 @@ public:
|
|||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
TensorType outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
|
@ -607,9 +607,9 @@ public:
|
|||
if (!lhsTy)
|
||||
return op.emitError("lhs must be a ranked tensor type");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
TensorType outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
Type outElemTy = outType.getElementType();
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
if (!rhsTy) {
|
||||
|
@ -917,9 +917,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
|
||||
->convertType(op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
|
@ -1421,9 +1421,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
|
||||
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
|
||||
SmallVector<APFloat> zeroConstVec(
|
||||
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
.getFloatSemantics()));
|
||||
numFeatureDimSize,
|
||||
APFloat::getZero(
|
||||
cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics()));
|
||||
SmallVector<APFloat> oneConstVec(
|
||||
numFeatureDimSize,
|
||||
APFloat(
|
||||
|
@ -1633,9 +1633,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
Location loc = op->getLoc();
|
||||
|
||||
// Get element type of resultType as dtype
|
||||
auto outType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outType = cast<RankedTensorType>(
|
||||
this->getTypeConverter()->convertType(op.getType()));
|
||||
auto dtype = outType.getElementType();
|
||||
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1678,7 +1677,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|||
AtenConstantPadNdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
auto selfElemTy = selfTy.getElementType();
|
||||
int64_t rank = selfTy.getRank();
|
||||
|
||||
|
@ -2029,7 +2028,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy.hasStaticShape()) {
|
||||
return op->emitError("dynamic shaped input is not supported");
|
||||
}
|
||||
|
@ -2062,7 +2061,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
cmpTypeAttr);
|
||||
|
||||
auto resTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
auto bcastTy = resTy.clone(rewriter.getI1Type());
|
||||
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
|
||||
|
@ -2071,15 +2070,15 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
|
||||
auto resElemTy = resTy.getElementType();
|
||||
Value zeroTensor;
|
||||
if (resElemTy.isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(resElemTy)) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
resTy, llvm::APFloat::getZero(
|
||||
resElemTy.cast<FloatType>().getFloatSemantics(), false));
|
||||
cast<FloatType>(resElemTy).getFloatSemantics(), false));
|
||||
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||
} else if (resElemTy.isa<mlir::IntegerType>()) {
|
||||
} else if (isa<mlir::IntegerType>(resElemTy)) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
resTy,
|
||||
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
|
||||
llvm::APInt::getZero(cast<mlir::IntegerType>(resElemTy).getWidth()));
|
||||
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||
} else {
|
||||
return op.emitError("element type is not float or integer");
|
||||
|
|
|
@ -157,8 +157,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
Value builtinTypeStart = adaptor.getStart();
|
||||
Value builtinTypeEnd = adaptor.getEnd();
|
||||
|
||||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||
torchTypeEnd.getType().isa<OptionalType>())
|
||||
if (isa<OptionalType>(torchTypeStart.getType()) ||
|
||||
isa<OptionalType>(torchTypeEnd.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
|
||||
int64_t step;
|
||||
|
@ -349,11 +349,11 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "offsets must be a vector with static shape equal to 1");
|
||||
|
||||
if (!op.getPaddingIdx().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getPaddingIdx().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: padding_idx should be none");
|
||||
|
||||
if (!op.getPerSampleWeights().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getPerSampleWeights().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: per_sample_weights should be none");
|
||||
|
||||
|
@ -453,25 +453,22 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
|||
loc, getTypeConverter()->convertType(op.getType(0)),
|
||||
stablehloReduceOp.getResult(0), outShapeTensor);
|
||||
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(1).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(1).getType()));
|
||||
Value resultB =
|
||||
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
||||
if (!resultB)
|
||||
return failure();
|
||||
|
||||
resultType = getTypeConverter()
|
||||
->convertType(op->getResult(2).getType())
|
||||
.cast<RankedTensorType>();
|
||||
resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(2).getType()));
|
||||
Value resultC =
|
||||
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
||||
if (!resultC)
|
||||
return failure();
|
||||
|
||||
resultType = getTypeConverter()
|
||||
->convertType(op->getResult(3).getType())
|
||||
.cast<RankedTensorType>();
|
||||
resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(3).getType()));
|
||||
Value resultD =
|
||||
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
||||
if (!resultD)
|
||||
|
@ -612,9 +609,8 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
|||
|
||||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
|
|
|
@ -350,9 +350,9 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op,
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>(),
|
||||
cast<RankedTensorType>(
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType())),
|
||||
output);
|
||||
|
||||
return success();
|
||||
|
@ -730,9 +730,8 @@ public:
|
|||
// If transposed is set to true,
|
||||
// the weight shape changes to [IC, (OC//G), KH, KW]
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto outTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
if (!inputTy || !weightTy || !outTy) {
|
||||
return op.emitError("input, weight and output must be ranked tensors");
|
||||
}
|
||||
|
|
|
@ -216,10 +216,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
|
@ -395,9 +395,8 @@ public:
|
|||
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
|
||||
Type inputElemTy = inputTy.getElementType();
|
||||
int64_t inputRank = inputTy.getRank();
|
||||
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
RankedTensorType outTy = cast<RankedTensorType>(
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
|
||||
auto outShape = outTy.getShape();
|
||||
|
||||
if (inputRank <= Dim) {
|
||||
|
|
|
@ -242,10 +242,10 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
|
@ -535,12 +535,10 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
"AtenMaxDimOp to StableHLO");
|
||||
}
|
||||
|
||||
RankedTensorType valResultType = getTypeConverter()
|
||||
->convertType(op.getResult(0).getType())
|
||||
.template cast<RankedTensorType>();
|
||||
RankedTensorType idxResultType = getTypeConverter()
|
||||
->convertType(op.getResult(1).getType())
|
||||
.template cast<RankedTensorType>();
|
||||
RankedTensorType valResultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult(0).getType()));
|
||||
RankedTensorType idxResultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult(1).getType()));
|
||||
Type idxElementType = idxResultType.getElementType();
|
||||
if (!isa<mlir::IntegerType>(idxElementType)) {
|
||||
return op.emitError("Aten.max.dim needs integer-like result");
|
||||
|
@ -636,9 +634,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
auto outTy =
|
||||
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
|
|
|
@ -271,7 +271,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
auto getOptionalVal = [&](Value val) -> std::optional<Value> {
|
||||
if (val.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(val.getType())) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return val;
|
||||
|
@ -451,7 +451,7 @@ template <>
|
|||
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
|
||||
PrimsSplitDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
|
||||
if (!selfType) {
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
}
|
||||
|
|
|
@ -292,7 +292,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
|||
arith::CmpIPredicate predicate = isDescending ? ge : le;
|
||||
compareOp = rewriter.create<arith::CmpIOp>(
|
||||
loc, predicate, block->getArgument(0), block->getArgument(1));
|
||||
} else if (elementTypes[0].isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(elementTypes[0])) {
|
||||
// Case for using arith::CmpFOp.
|
||||
arith::CmpFPredicate predicate =
|
||||
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
|
||||
|
@ -349,8 +349,8 @@ public:
|
|||
b.create<TMTensor::YieldOp>(loc, updatesElement);
|
||||
});
|
||||
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
return success();
|
||||
}
|
||||
|
@ -381,7 +381,7 @@ public:
|
|||
// Check whether the input is a 1-d tensor of integer type or not.
|
||||
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType.getRank() != 1 ||
|
||||
!inputType.getElementType().isa<mlir::IntegerType>())
|
||||
!isa<mlir::IntegerType>(inputType.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Input tensor has to be a one-dimensional tensor of integer type.");
|
||||
|
@ -395,7 +395,7 @@ public:
|
|||
"Unimplemented: Integer width not equal to 64 are not supported.");
|
||||
|
||||
// TODO: Incorporate the weight argument.
|
||||
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
|
||||
if (!isa<mlir::torch::Torch::NoneType>(weights.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: the weights operand is not incorporated.");
|
||||
|
||||
|
@ -439,8 +439,8 @@ public:
|
|||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
Type resultElemType = resultType.getElementType();
|
||||
|
||||
SmallVector<Value, 1> inputSizeDynamic =
|
||||
|
@ -686,8 +686,8 @@ public:
|
|||
auto valuesType = cast<ValueTensorType>(values.getType());
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
if (!valuesTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -823,10 +823,10 @@ public:
|
|||
Value inputElement) {
|
||||
Value yieldValue = valuesElement;
|
||||
if (accumulate) {
|
||||
if (inputElement.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(inputElement.getType())) {
|
||||
yieldValue =
|
||||
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
|
||||
} else if (inputElement.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(inputElement.getType())) {
|
||||
yieldValue =
|
||||
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
|
||||
} else {
|
||||
|
@ -1042,10 +1042,10 @@ public:
|
|||
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||
Value inputElement) {
|
||||
Value yieldValue = valuesElement;
|
||||
if (inputElement.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(inputElement.getType())) {
|
||||
yieldValue =
|
||||
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
|
||||
} else if (inputElement.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(inputElement.getType())) {
|
||||
yieldValue =
|
||||
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
|
||||
} else {
|
||||
|
@ -1204,33 +1204,33 @@ public:
|
|||
Value result;
|
||||
if (reduceEnum == torch_upstream::ReductionType::SUM ||
|
||||
reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(update.getType())) {
|
||||
result = b.create<arith::AddIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(update.getType())) {
|
||||
result = b.create<arith::AddFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(update.getType())) {
|
||||
result = b.create<arith::MulIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(update.getType())) {
|
||||
result = b.create<arith::MulFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(update.getType())) {
|
||||
result = b.create<arith::MaxSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(update.getType())) {
|
||||
result = b.create<arith::MaximumFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
if (isa<mlir::IntegerType>(update.getType())) {
|
||||
result = b.create<arith::MinSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
} else if (isa<mlir::FloatType>(update.getType())) {
|
||||
result = b.create<arith::MinimumFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
|
@ -1285,9 +1285,8 @@ public:
|
|||
})
|
||||
.getResult()[0];
|
||||
}
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
|
||||
return success();
|
||||
|
@ -1392,9 +1391,8 @@ public:
|
|||
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
Type inputElementType =
|
||||
cast<RankedTensorType>(input.getType()).getElementType();
|
||||
|
@ -1414,7 +1412,7 @@ public:
|
|||
|
||||
int64_t inputRank = resultType.getRank();
|
||||
Value dtype = op.getDtype();
|
||||
if (!dtype.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(dtype.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: dtype argument not supported");
|
||||
|
||||
|
@ -1444,7 +1442,7 @@ public:
|
|||
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
|
||||
[](OpBuilder &b, Location loc, Value input, Value acc) {
|
||||
Value sum =
|
||||
(input.getType().isa<mlir::FloatType>()
|
||||
(isa<mlir::FloatType>(input.getType())
|
||||
? b.create<arith::AddFOp>(loc, input, acc)->getResult(0)
|
||||
: b.create<arith::AddIOp>(loc, input, acc)->getResult(0));
|
||||
b.create<TMTensor::YieldOp>(loc, sum);
|
||||
|
@ -1472,7 +1470,7 @@ public:
|
|||
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
|
||||
|
||||
// Verify inputs (only support defaults)
|
||||
if (!mask.getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(mask.getType()))
|
||||
return rewriter.notifyMatchFailure(op.getLoc(),
|
||||
"attention masking not supported");
|
||||
double dropout;
|
||||
|
@ -1483,7 +1481,7 @@ public:
|
|||
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "causal attention masking not supported");
|
||||
if (!scale.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(scale.getType())) {
|
||||
double scaleFloat;
|
||||
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
|
||||
scaleFloat != 1.0)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
////
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
@ -47,7 +47,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
|
@ -99,9 +99,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outTy = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
auto binaryOp =
|
||||
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
|
||||
|
@ -248,9 +248,9 @@ public:
|
|||
}
|
||||
|
||||
// Get output type: tensor<i32/i64/f32>
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
|
@ -373,9 +373,9 @@ public:
|
|||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
|
||||
// Promote lhs and rhs dtypes for bitwise operators.
|
||||
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
TensorType resultTy = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
if (isBitwiseOp) {
|
||||
lhs = tosa::promoteType(rewriter, lhs, resultTy);
|
||||
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
|
||||
|
@ -416,9 +416,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
|
@ -444,9 +444,9 @@ public:
|
|||
}
|
||||
|
||||
if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||
rhsTensor, /*shift=*/0);
|
||||
|
@ -492,9 +492,9 @@ public:
|
|||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
// auto result;
|
||||
Value result;
|
||||
|
@ -540,7 +540,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
|
@ -557,7 +557,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
|
@ -584,7 +584,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
// Rescale the clampIn for quantized types. TBD
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
|
@ -604,7 +604,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
|
|||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
|
@ -667,9 +667,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto outputTy = cast<RankedTensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
if (!outputTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor type outputs permitted for reduce_mean");
|
||||
|
@ -828,9 +828,8 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const keepdim parameter unsupported");
|
||||
|
||||
auto resultTy = getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultTy = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult().getType()));
|
||||
auto outputETy = resultTy.getElementType();
|
||||
|
||||
// Create a single instance of tosa.argmax.
|
||||
|
@ -927,9 +926,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Squeeze could not compute new shape");
|
||||
|
||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto resultTy = cast<RankedTensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getResult().getType()));
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
auto newOutputTy = RankedTensorType::get(
|
||||
|
@ -1017,7 +1016,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Pow");
|
||||
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>())
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
|
||||
|
@ -1624,9 +1623,9 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>(),
|
||||
cast<RankedTensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType())),
|
||||
output);
|
||||
|
||||
return success();
|
||||
|
@ -1800,9 +1799,9 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>(),
|
||||
cast<RankedTensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType())),
|
||||
matmulPlusBias);
|
||||
|
||||
return success();
|
||||
|
@ -1823,7 +1822,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>())
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
|
||||
|
@ -1869,9 +1868,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||
auto outputTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto outputTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
if (!inputTy || !weightTy || !outputTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -2208,7 +2206,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
// Note: cudnn_enabled is not handled.
|
||||
|
||||
// FIXME: Handle training and momentum.
|
||||
if (op.getMomentum().getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(op.getMomentum().getType()))
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
|
||||
|
||||
auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
|
||||
|
@ -2312,9 +2310,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
// Note: cudnn_enabled is not handled.
|
||||
|
||||
// FIXME: Handle the None cases for the optional parameters.
|
||||
if (adaptor.getWeight().getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(adaptor.getWeight().getType()))
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported None for weight");
|
||||
if (adaptor.getBias().getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(adaptor.getBias().getType()))
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
|
||||
|
||||
auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
|
||||
|
@ -2453,9 +2451,8 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto outputTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto outputTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
// Tensors with integer types need to be converted to signless integer
|
||||
// element type. All tensors with element types other than integer can reuse
|
||||
|
@ -3122,7 +3119,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
|
||||
if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
|
||||
if (!indicesType || !isa<IntegerType>(indicesType.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices must be of integer tensor type");
|
||||
|
||||
|
@ -3632,11 +3629,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
auto indexTorch = tensorsTorchType[i];
|
||||
// TODO add support for none index other than i==0, like (index0, None)
|
||||
// (None, index1)
|
||||
if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) {
|
||||
if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
|
||||
// convert None to [0,0,0]
|
||||
auto indexNext = indexTensors[i + 1];
|
||||
auto indexNextTorch = tensorsTorchType[i + 1];
|
||||
if (indexNextTorch.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(indexNextTorch.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Multiple None index is not support for now.");
|
||||
}
|
||||
|
@ -3963,8 +3960,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
|
|||
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
if (!selfType.getElementType().isa<mlir::FloatType>() ||
|
||||
!otherType.getElementType().isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(selfType.getElementType()) ||
|
||||
!isa<mlir::FloatType>(otherType.getElementType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only FP element type is supported");
|
||||
}
|
||||
|
@ -4058,9 +4055,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
// At this point all tensors should have value semantics, and hence the
|
||||
// `layout` check can be ignored.
|
||||
|
@ -4068,7 +4064,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
// TODO: Add support for pin_memory features.
|
||||
// The pin_memory should be either `False` or `none`.
|
||||
bool pinMemory;
|
||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4162,10 +4158,10 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
};
|
||||
|
||||
const auto isIntType =
|
||||
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
|
||||
dyn_cast_or_null<mlir::IntegerType>(resultType.getElementType());
|
||||
|
||||
const auto isDoubleType =
|
||||
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();
|
||||
dyn_cast_or_null<mlir::FloatType>(resultType.getElementType());
|
||||
|
||||
auto maybeResult = [&]() -> std::optional<Value> {
|
||||
// Integer output type, and start / end / range are all integers.
|
||||
|
@ -4218,9 +4214,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
// Only supports integer operand type, because for the floating point operand
|
||||
// type result tensor has to be of type `f64` which is not supported in the
|
||||
|
@ -4323,7 +4318,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4336,9 +4331,8 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
"memory_format is supported");
|
||||
}
|
||||
|
||||
auto resultTy = getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultTy = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult().getType()));
|
||||
|
||||
Value result;
|
||||
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),
|
||||
|
@ -4779,9 +4773,9 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
auto outType = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
if (!outType)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -4841,9 +4835,9 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
auto outType = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
if (!outType || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4875,9 +4869,9 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
auto outType = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
if (!outType || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -4947,9 +4941,9 @@ public:
|
|||
"unimplemented: only contiguous and channels last memory "
|
||||
"format is supported");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
auto outType = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.getSelf());
|
||||
|
||||
return success();
|
||||
|
@ -5077,8 +5071,8 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||
auto elementType = resultType.getElementType();
|
||||
|
||||
if (isa<mlir::IntegerType>(selfTy.getElementType())) {
|
||||
|
|
|
@ -813,9 +813,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
|
|||
return std::nullopt;
|
||||
|
||||
bool input_is_qtype =
|
||||
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
bool output_is_qtype =
|
||||
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
if (input_is_qtype || output_is_qtype) {
|
||||
op->emitOpError("ConvertReduceProdOp: input/output tensor should "
|
||||
|
@ -839,9 +839,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
|
|||
return std::nullopt;
|
||||
|
||||
bool input_is_qtype =
|
||||
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
bool output_is_qtype =
|
||||
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
if (input_is_qtype != output_is_qtype) {
|
||||
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
|
||||
|
@ -894,9 +894,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
return std::nullopt;
|
||||
|
||||
bool input_is_qtype =
|
||||
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||
bool output_is_qtype =
|
||||
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
|
||||
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||
|
||||
if (input_is_qtype != output_is_qtype) {
|
||||
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
|
||||
|
@ -905,7 +905,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
// Only supports float type mean() if it's non-quantized
|
||||
if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
|
||||
if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
|
||||
op->emitWarning(
|
||||
"Failed convertReduceMean: input unquantized type but output element "
|
||||
"not FloatType!");
|
||||
|
|
|
@ -31,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
|
|||
return false;
|
||||
auto tensor = dyn_cast<ValueTensorType>(type);
|
||||
return !tensor ||
|
||||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
|
||||
dyn_cast_or_null<RankedTensorType>(tensor.toBuiltinTensor());
|
||||
};
|
||||
|
||||
bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
|
||||
|
@ -66,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
|||
|
||||
// Generate IR: assert(dim >= 0 && dim < inputRank)
|
||||
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
|
||||
assert(dim.getType().isa<IntegerType>() &&
|
||||
assert(isa<IntegerType>(dim.getType()) &&
|
||||
"dim arg of assertIsValidDim must be integer type");
|
||||
Value cst0 =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
|
||||
|
@ -139,12 +139,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
}
|
||||
|
||||
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
|
||||
assert(v.getType().isa<IntegerType>() && "must be called with integer type");
|
||||
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
|
||||
return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
|
||||
}
|
||||
|
||||
Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
|
||||
assert(idx.getType().isa<IndexType>() && "must be called with integer type");
|
||||
assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
|
||||
return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
|
||||
}
|
||||
|
||||
|
@ -375,7 +375,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value torchOptionalInt, Value builtinInt,
|
||||
Value defaultValue, Value dimSize) {
|
||||
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(torchOptionalInt.getType()))
|
||||
return defaultValue;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value positiveDim =
|
||||
|
|
|
@ -149,14 +149,12 @@ static Value getScalarIntValue(Value input, Location loc,
|
|||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
if (inputDtype.isInteger(64)) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseIntElementsAttr>()
|
||||
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
|
||||
.getSplatValue<int64_t>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
} else {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseIntElementsAttr>()
|
||||
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
|
||||
.getSplatValue<bool>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
|
@ -191,8 +189,7 @@ static Value getScalarFloatValue(Value input, Location loc,
|
|||
return nullptr;
|
||||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseFPElementsAttr>()
|
||||
auto val = cast<DenseFPElementsAttr>(valueTensorLiteralOp.getValue())
|
||||
.getSplatValue<FloatAttr>()
|
||||
.getValueAsDouble();
|
||||
return rewriter.create<Torch::ConstantFloatOp>(
|
||||
|
@ -1946,7 +1943,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
|||
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
|
@ -2136,7 +2133,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
|
|||
// Limit the loop count to 6 to avoid indefinite compilation times from
|
||||
// unbounded IR traversals.
|
||||
for (auto idx = 0; idx < 6; ++idx) {
|
||||
if (!value || !value.getType().isa<BaseTensorType>())
|
||||
if (!value || !isa<BaseTensorType>(value.getType()))
|
||||
return failure();
|
||||
|
||||
auto tensorType = cast<BaseTensorType>(value.getType());
|
||||
|
@ -2518,7 +2515,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold int -> float conversion.
|
||||
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (auto integerAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getA())) {
|
||||
return FloatAttr::get(
|
||||
mlir::Float64Type::get(getContext()),
|
||||
static_cast<double>(integerAttr.getValue().getSExtValue()));
|
||||
|
@ -2535,7 +2532,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64),
|
||||
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
|
||||
|
@ -2549,7 +2546,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64),
|
||||
static_cast<long>(floatAttr.getValue().convertToDouble()));
|
||||
|
@ -2695,9 +2692,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
|
|||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto attr = properties.as<Properties *>()
|
||||
->getValue()
|
||||
.dyn_cast_or_null<ElementsAttr>();
|
||||
auto attr =
|
||||
dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
|
||||
if (!attr)
|
||||
return failure();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||
|
@ -2723,10 +2719,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
|
|||
|
||||
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
|
||||
TypeRange actual) {
|
||||
if (!actual[0].isa<BaseTensorType>())
|
||||
if (!isa<BaseTensorType>(actual[0]))
|
||||
return false;
|
||||
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
|
||||
actual[0].cast<BaseTensorType>());
|
||||
return areSizesAndDtypesCompatible(cast<BaseTensorType>(inferred[0]),
|
||||
cast<BaseTensorType>(actual[0]));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2737,9 +2733,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
|||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto attr = properties.as<Properties *>()
|
||||
->getValue()
|
||||
.dyn_cast_or_null<ElementsAttr>();
|
||||
auto attr =
|
||||
dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
|
||||
if (!attr)
|
||||
return failure();
|
||||
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||
|
@ -2760,8 +2755,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
|
||||
mlir::TypeRange outputs) {
|
||||
return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
|
||||
outputs[0].cast<BaseTensorType>());
|
||||
return areSizesAndDtypesCompatible(cast<BaseTensorType>(inputs[0]),
|
||||
cast<BaseTensorType>(outputs[0]));
|
||||
}
|
||||
|
||||
void TensorStaticInfoCastOp::getCanonicalizationPatterns(
|
||||
|
@ -3072,7 +3067,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
|||
if (!operandType)
|
||||
return nullptr;
|
||||
if (operandType.hasDtype()) {
|
||||
bool isFloatType = operandType.getDtype().isa<mlir::FloatType>();
|
||||
bool isFloatType = isa<mlir::FloatType>(operandType.getDtype());
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
|
||||
}
|
||||
// doesn't has dtype
|
||||
|
@ -3130,12 +3125,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
int64_t start;
|
||||
int64_t end;
|
||||
int64_t step;
|
||||
if (op.getStart().getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getStart().getType())) {
|
||||
start = 0;
|
||||
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
||||
return failure();
|
||||
}
|
||||
if (op.getEnd().getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(op.getEnd().getType())) {
|
||||
end = listElements.size();
|
||||
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
||||
return failure();
|
||||
|
@ -3228,7 +3223,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// things.
|
||||
Value replacement = tupleConstruct.getElements()[i];
|
||||
if (replacement.getType() != op.getType()) {
|
||||
if (op.getType().isa<BaseTensorType>()) {
|
||||
if (isa<BaseTensorType>(op.getType())) {
|
||||
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
||||
op.getLoc(), op.getType(), replacement);
|
||||
} else {
|
||||
|
@ -3384,8 +3379,8 @@ using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
|||
static OpFoldResult
|
||||
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||
BinaryIntOperatorFn f) {
|
||||
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||
auto intLhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
|
||||
auto intRhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
|
||||
if (!intLhs || !intRhs) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -3711,7 +3706,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
||||
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) -> int64_t { return a + b; });
|
||||
|
@ -3730,7 +3725,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
||||
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) -> int64_t { return a * b; });
|
||||
|
@ -3749,7 +3744,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
||||
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
|
@ -3806,7 +3801,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
|
|||
if (!adaptor.getA()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
||||
auto floatValue = dyn_cast_or_null<FloatAttr>(adaptor.getA());
|
||||
if (!floatValue) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -3834,7 +3829,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
|
|||
if (!adaptor.getA()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
||||
auto value = dyn_cast_or_null<FloatAttr>(adaptor.getA());
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -4487,8 +4482,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
|
|||
if (getA() == getB())
|
||||
return getA();
|
||||
|
||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
// Torch semantics are that !torch.int is 64-bit signed.
|
||||
|
@ -4556,8 +4551,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
|||
if (getA() == getB())
|
||||
return getA();
|
||||
|
||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
// Torch semantics are that !torch.int is 64-bit signed.
|
||||
|
@ -4644,8 +4639,8 @@ LogicalResult AtenNormScalarOp::verify() {
|
|||
// Check if dtype is one of those supported by norm operation.
|
||||
// ComplexType will match any torch complex types, but each float must be
|
||||
// checked individually.
|
||||
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
|
||||
mlir::Float32Type, mlir::Float64Type>()) {
|
||||
if (!isa<mlir::ComplexType, mlir::Float16Type, mlir::Float32Type,
|
||||
mlir::Float64Type>(inTensorDtype)) {
|
||||
return emitOpError(
|
||||
"expected a float or complex type for input tensor, but got ")
|
||||
<< inTensorDtype;
|
||||
|
|
|
@ -190,8 +190,8 @@ static bool isValidTorchDtype(Type dtype) {
|
|||
// Builtin floating point types.
|
||||
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
|
||||
return true;
|
||||
if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(dtype))
|
||||
return true;
|
||||
|
||||
if (isa<Torch::StringType>(dtype))
|
||||
|
@ -228,9 +228,9 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
|
|||
|
||||
Type BaseTensorType::getWithSizesAndDtype(
|
||||
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
|
||||
if (isa<NonValueTensorType>())
|
||||
if (mlir::isa<NonValueTensorType>(*this))
|
||||
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||
if (isa<ValueTensorType>())
|
||||
if (mlir::isa<ValueTensorType>(*this))
|
||||
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
@ -248,9 +248,9 @@ Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
|
|||
}
|
||||
|
||||
ValueTensorType BaseTensorType::getWithValueSemantics() const {
|
||||
if (auto tensor = dyn_cast<NonValueTensorType>())
|
||||
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
|
||||
return tensor.getWithValueSemantics();
|
||||
if (auto tensor = dyn_cast<ValueTensorType>())
|
||||
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
|
||||
return tensor;
|
||||
llvm_unreachable("not a BaseTensorType!");
|
||||
}
|
||||
|
|
|
@ -110,7 +110,7 @@ public:
|
|||
continue;
|
||||
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
||||
if (it != typeBoundMap.end()) {
|
||||
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
|
||||
if (auto valueTensorType = dyn_cast<ValueTensorType>(it->second)) {
|
||||
newOperands.push_back(copyTensorToType(
|
||||
rewriter, call->getLoc(), valueTensorType, operand.value()));
|
||||
continue;
|
||||
|
@ -215,11 +215,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
|
|||
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
|
||||
if (func.getArgAttr(i, "torch.type_bound"))
|
||||
return false;
|
||||
if (func.getArgumentTypes()[i].isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(func.getArgumentTypes()[i]))
|
||||
return false;
|
||||
}
|
||||
for (int i = 0, e = func.getNumResults(); i != e; i++) {
|
||||
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(func.getFunctionType().getResults()[i]))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -38,7 +38,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
if (failed(resDtype))
|
||||
return false;
|
||||
return resDtype->isa<mlir::FloatType>();
|
||||
return isa<mlir::FloatType>(*resDtype);
|
||||
}
|
||||
|
||||
// Helper function to compute the return type of the reduction function.
|
||||
|
@ -99,19 +99,15 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
Operation *op, Value input, Value dim,
|
||||
bool keepDim) {
|
||||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||
BaseTensorType valueType =
|
||||
computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
|
||||
dim, keepDim)
|
||||
.cast<BaseTensorType>();
|
||||
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
|
||||
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
|
||||
if (!valueType)
|
||||
return nullptr;
|
||||
BaseTensorType indexType =
|
||||
valueType
|
||||
.getWithSizesAndDtype(
|
||||
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
|
||||
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::ArrayRef(valueType.getSizes()),
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
||||
.cast<BaseTensorType>();
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
|
||||
return rewriter
|
||||
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
||||
.getValues();
|
||||
|
@ -1059,7 +1055,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
||||
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
||||
if (!outType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
@ -1659,11 +1655,9 @@ public:
|
|||
unsigned inputRank = *maybeInputRank;
|
||||
if (!indicesTensorType.hasSizes())
|
||||
return failure();
|
||||
BaseTensorType valueTensorType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
||||
inputType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
BaseTensorType valueTensorType = cast<BaseTensorType>(
|
||||
inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
||||
inputType.getOptionalDtype()));
|
||||
|
||||
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
||||
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
||||
|
@ -1671,10 +1665,8 @@ public:
|
|||
// happens on the 0th dimension.
|
||||
if (isa<Torch::NoneType>(dim.getType())) {
|
||||
BaseTensorType flattenType =
|
||||
inputType
|
||||
.getWithSizesAndDtype({kUnknownSize},
|
||||
inputType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
||||
{kUnknownSize}, inputType.getOptionalDtype()));
|
||||
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value end = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
||||
|
@ -3003,7 +2995,7 @@ public:
|
|||
bool dimIsNone = false;
|
||||
int64_t dim;
|
||||
Value dimValue = op.getDim();
|
||||
if (dimValue.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dimValue.getType())) {
|
||||
dimIsNone = true;
|
||||
dim = inputRank - 1;
|
||||
} else {
|
||||
|
@ -3887,10 +3879,9 @@ public:
|
|||
gradOutputViewSizesInt[0] = kUnknownSize;
|
||||
gradOutputViewSizesInt[1] = 1;
|
||||
BaseTensorType gradOutputTypeForView =
|
||||
gradOutputTy
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt),
|
||||
gradOutputTy.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(gradOutputTy.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(gradOutputViewSizesInt),
|
||||
gradOutputTy.getOptionalDtype()));
|
||||
Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
||||
|
||||
|
@ -3918,10 +3909,9 @@ public:
|
|||
}
|
||||
|
||||
BaseTensorType gradWeightTy =
|
||||
inputTransposedTy
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
||||
inputTransposedTy.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(inputTransposedTy.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(gradWeightSizesInt),
|
||||
inputTransposedTy.getOptionalDtype()));
|
||||
|
||||
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
||||
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
||||
|
@ -3937,10 +3927,9 @@ public:
|
|||
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
|
||||
gradWeightSizesInt[i + 2] = weightSizes[i + 2];
|
||||
BaseTensorType gradWeightNarrowTy =
|
||||
gradWeightTy
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
||||
gradWeightTy.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(gradWeightSizesInt),
|
||||
gradWeightTy.getOptionalDtype()));
|
||||
|
||||
Value dim = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i + 2));
|
||||
|
@ -3970,10 +3959,9 @@ public:
|
|||
gradWeightViewShapeValue);
|
||||
|
||||
BaseTensorType gradWeightTypeForView =
|
||||
gradWeightTy
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt),
|
||||
gradWeightTy.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(gradWeightViewShapeInt),
|
||||
gradWeightTy.getOptionalDtype()));
|
||||
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
||||
|
||||
|
@ -3986,10 +3974,9 @@ public:
|
|||
gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
|
||||
}
|
||||
BaseTensorType gradWeightTypeForMoveDim =
|
||||
gradWeightTy
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape),
|
||||
gradWeightTy.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(gradWeightMoveDimShape),
|
||||
gradWeightTy.getOptionalDtype()));
|
||||
|
||||
gradWeight = rewriter.create<AtenMovedimIntOp>(
|
||||
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
|
||||
|
@ -4009,8 +3996,7 @@ public:
|
|||
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc, transposedType, gradOutput, cstZero, cstOne);
|
||||
// Convolve input with grad_output.
|
||||
if (failed(
|
||||
getTransposedType(op.getResultTypes()[1].cast<BaseTensorType>(),
|
||||
if (failed(getTransposedType(cast<BaseTensorType>(op.getResultTypes()[1]),
|
||||
0, 1, transposedType)))
|
||||
return failure();
|
||||
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
||||
|
@ -4063,7 +4049,7 @@ public:
|
|||
|
||||
// TODO: Handle integer type operands.
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non-floating point dtype");
|
||||
}
|
||||
|
@ -4125,7 +4111,7 @@ public:
|
|||
MLIRContext *context = op.getContext();
|
||||
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
||||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()) ||
|
||||
!isNoneOrFloatDtype(context, dtype)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only floating-point type is supported");
|
||||
|
@ -4133,7 +4119,7 @@ public:
|
|||
|
||||
SmallVector<Value> dimListElements;
|
||||
if (!getListConstructElements(dimList, dimListElements) &&
|
||||
!dimList.getType().isa<Torch::NoneType>()) {
|
||||
!isa<Torch::NoneType>(dimList.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected `dim` to be `None` or constructed from list construct");
|
||||
}
|
||||
|
@ -4215,7 +4201,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
||||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support floating type input for training mode");
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
@ -4243,7 +4229,7 @@ public:
|
|||
Value input = op.getInput();
|
||||
Value prob = op.getP();
|
||||
bool train = false;
|
||||
if (!op.getTrain().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getTrain().getType())) {
|
||||
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "train must be a boolean constant or none");
|
||||
|
@ -4263,7 +4249,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support floating type input for training mode");
|
||||
}
|
||||
|
@ -4332,7 +4318,7 @@ public:
|
|||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorTy.hasDtype() ||
|
||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||||
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only aten.std support floating type");
|
||||
}
|
||||
|
@ -4388,7 +4374,7 @@ public:
|
|||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorType.hasDtype() ||
|
||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "aten.std.dim expects input tensor of floating-point type");
|
||||
}
|
||||
|
@ -4413,7 +4399,7 @@ public:
|
|||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||
if (!inputTensorType.hasDtype() ||
|
||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"aten.std.correction expects input tensor of floating-point type");
|
||||
|
@ -4506,7 +4492,7 @@ public:
|
|||
Value input = op.getSelf();
|
||||
Type resultType = op.getType();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support floating-point type");
|
||||
}
|
||||
|
@ -4547,7 +4533,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
|||
op, "can't decompose bernoulli like ops without sizes or dtype");
|
||||
}
|
||||
// The `prob` is expected to be a float type tensor.
|
||||
if (!probType.getDtype().isa<mlir::FloatType>()) {
|
||||
if (!isa<mlir::FloatType>(probType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "probabilities must be a float type tensor");
|
||||
}
|
||||
|
@ -4582,7 +4568,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
@ -4640,7 +4626,7 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value prob = op.getP();
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
@ -4665,7 +4651,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenExponentialOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
@ -4706,7 +4692,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to be None because only global default "
|
||||
"generator is supported");
|
||||
|
@ -4984,10 +4970,10 @@ class DecomposeAtenNativeLayerNormOp
|
|||
|
||||
Value weight = op.getWeight();
|
||||
Value bias = op.getBias();
|
||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
|
||||
}
|
||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||
out =
|
||||
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
|
||||
}
|
||||
|
@ -5238,13 +5224,13 @@ class DecomposeAtenNativeGroupNormOp
|
|||
loc, ListType::get(IntType::get(context)), viewShape);
|
||||
|
||||
Value groupNormOutput = reshapedOutput;
|
||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||
auto weightReshaped = rewriter.create<AtenViewOp>(
|
||||
loc, baseType, weight, /*shape=*/viewShapeSizeList);
|
||||
groupNormOutput = rewriter.create<AtenMulTensorOp>(
|
||||
loc, inputType, groupNormOutput, weightReshaped);
|
||||
}
|
||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||
auto biasReshaped = rewriter.create<AtenViewOp>(
|
||||
loc, baseType, bias, /*shape=*/viewShapeSizeList);
|
||||
groupNormOutput = rewriter.create<AtenAddTensorOp>(
|
||||
|
@ -5297,8 +5283,8 @@ class DecomposeAtenNativeBatchNormOp
|
|||
|
||||
// In the inference mode, the `runningMean` and `runningVar` must not be
|
||||
// None.
|
||||
if (runningMean.getType().isa<Torch::NoneType>() ||
|
||||
runningVar.getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(runningMean.getType()) ||
|
||||
isa<Torch::NoneType>(runningVar.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "running stats must not be None in inference mode");
|
||||
|
||||
|
@ -5354,7 +5340,7 @@ class DecomposeAtenNativeBatchNormOp
|
|||
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
||||
// 3. output = normalizedInput * weight + bias
|
||||
Value batchNormOutput = normalizedInput;
|
||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||
// Rank of `weight` must be exactly 1.
|
||||
std::optional<unsigned> weightRank = getTensorRank(weight);
|
||||
if (!weightRank || *weightRank != 1)
|
||||
|
@ -5364,7 +5350,7 @@ class DecomposeAtenNativeBatchNormOp
|
|||
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
||||
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
||||
}
|
||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||
// Rank of `bias` must be exactly 1.
|
||||
std::optional<unsigned> biasRank = getTensorRank(bias);
|
||||
if (!biasRank || *biasRank != 1)
|
||||
|
@ -5444,7 +5430,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5518,7 +5504,7 @@ public:
|
|||
return transposeWeight;
|
||||
};
|
||||
|
||||
if (bias.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(bias.getType())) {
|
||||
auto weightRank = weightType.getSizes().size();
|
||||
if (weightRank > 2 || weightRank <= 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5622,7 +5608,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenNewFullOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5718,7 +5704,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5743,9 +5729,9 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value value = op.getValue();
|
||||
if (value.getType().isa<Torch::OptionalType>())
|
||||
if (isa<Torch::OptionalType>(value.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
||||
if (value.getType().isa<Torch::NoneType>())
|
||||
if (isa<Torch::NoneType>(value.getType()))
|
||||
value = rewriter.create<Torch::ConstantFloatOp>(
|
||||
op.getLoc(), rewriter.getF64FloatAttr(0));
|
||||
|
||||
|
@ -5765,7 +5751,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO: Add support for pinMemory arg equal to `True`.
|
||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
|
||||
bool pinMemory;
|
||||
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5776,7 +5762,7 @@ public:
|
|||
}
|
||||
|
||||
// TODO: Add support for device arg other than cpu.
|
||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||
std::string device;
|
||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -5788,7 +5774,7 @@ public:
|
|||
|
||||
// TODO: Add support for non-strided layout.
|
||||
// torch.layout is by default strided i.e. 0.
|
||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -6254,7 +6240,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
||||
outputTensorType.getSizes(), rewriter.getF64Type());
|
||||
if (!inputTensorTy.hasDtype() ||
|
||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||||
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "support floating-point type input only");
|
||||
}
|
||||
|
@ -6391,14 +6377,14 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
int64_t correctionValInt;
|
||||
double correctionValFloat = 1.0;
|
||||
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
|
||||
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getCorrection().getType())) {
|
||||
if (isa<Torch::FloatType>(op.getCorrection().getType())) {
|
||||
if (!matchPattern(op.getCorrection(),
|
||||
m_TorchConstantFloat(&correctionValFloat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support constant int or float correction value for "
|
||||
"aten.var");
|
||||
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
|
||||
} else if (isa<Torch::IntType>(op.getCorrection().getType())) {
|
||||
if (!matchPattern(op.getCorrection(),
|
||||
m_TorchConstantInt(&correctionValInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -6525,11 +6511,9 @@ public:
|
|||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected the input tensor to have sizes");
|
||||
BaseTensorType subType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
||||
resultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
BaseTensorType subType = cast<BaseTensorType>(
|
||||
inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
||||
resultType.getOptionalDtype()));
|
||||
|
||||
Value sub =
|
||||
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
|
||||
|
@ -6566,7 +6550,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value ord = op.getP();
|
||||
if (ord.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(ord.getType())) {
|
||||
ord = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(2.0));
|
||||
}
|
||||
|
@ -6609,10 +6593,8 @@ public:
|
|||
loc, rewriter.getF64FloatAttr((double)cstHigh));
|
||||
|
||||
BaseTensorType floatResultType =
|
||||
resultTensorType
|
||||
.getWithSizesAndDtype(resultTensorType.getSizes(),
|
||||
rewriter.getF32Type())
|
||||
.cast<BaseTensorType>();
|
||||
cast<BaseTensorType>(resultTensorType.getWithSizesAndDtype(
|
||||
resultTensorType.getSizes(), rewriter.getF32Type()));
|
||||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||
loc, floatResultType, op.getSize(), /*dtype=*/none,
|
||||
/*layout=*/op.getLayout(),
|
||||
|
@ -6704,7 +6686,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimsVarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
|
||||
if (!isa<Torch::NoneType>(op.getOutputDtype().getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None dtype for prims::var op");
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
|
@ -6816,7 +6798,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)))
|
||||
|
@ -6913,8 +6895,8 @@ public:
|
|||
op.getDevice(), op.getPinMemory());
|
||||
// calculate (end - start) / (steps - 1)
|
||||
Value sub;
|
||||
if (op.getEnd().getType().isa<Torch::FloatType>() ||
|
||||
op.getStart().getType().isa<Torch::FloatType>()) {
|
||||
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
|
||||
isa<Torch::FloatType>(op.getStart().getType())) {
|
||||
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
||||
op.getEnd(), op.getStart());
|
||||
} else {
|
||||
|
@ -6930,7 +6912,7 @@ public:
|
|||
}
|
||||
// to dtype
|
||||
Value result;
|
||||
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
||||
/*copy=*/falseVal, /*memory_format=*/none);
|
||||
|
@ -7344,11 +7326,8 @@ public:
|
|||
|
||||
auto selfType = cast<BaseTensorType>(self.getType());
|
||||
auto indexType = cast<BaseTensorType>(index.getType());
|
||||
BaseTensorType srcType =
|
||||
selfType
|
||||
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
||||
selfType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
BaseTensorType srcType = cast<BaseTensorType>(selfType.getWithSizesAndDtype(
|
||||
indexType.getOptionalSizes(), selfType.getOptionalDtype()));
|
||||
Value src =
|
||||
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
|
||||
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
|
||||
|
@ -7372,7 +7351,7 @@ public:
|
|||
"expected result type to have dtype");
|
||||
}
|
||||
// TODO: support complex type in future.
|
||||
if (outType.getDtype().isa<mlir::ComplexType>()) {
|
||||
if (isa<mlir::ComplexType>(outType.getDtype())) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"doesn't support complex type now");
|
||||
}
|
||||
|
@ -7488,7 +7467,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
|
|||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -7497,7 +7476,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
|
|||
|
||||
int64_t maxIndexRank = 0;
|
||||
for (auto index : oldIndices) {
|
||||
auto indexType = index.getType().dyn_cast<BaseTensorType>();
|
||||
auto indexType = dyn_cast<BaseTensorType>(index.getType());
|
||||
if (!indexType) // None index
|
||||
continue;
|
||||
if (!indexType.hasSizes())
|
||||
|
@ -7586,15 +7565,13 @@ public:
|
|||
int64_t inputRank = inputSizes.size();
|
||||
|
||||
auto isTensor = [](Value v) {
|
||||
return v.getType().isa<Torch::BaseTensorType>();
|
||||
return isa<Torch::BaseTensorType>(v.getType());
|
||||
};
|
||||
|
||||
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
|
||||
if (llvm::all_of(indices, isTensor)) {
|
||||
// By default, we regard the first index type as the list element type.
|
||||
auto indexElemType = indices[0]
|
||||
.getType()
|
||||
.template cast<BaseTensorType>()
|
||||
auto indexElemType = cast<BaseTensorType>(indices[0].getType())
|
||||
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||
auto newIndices = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(indexElemType), indices);
|
||||
|
@ -7684,7 +7661,7 @@ public:
|
|||
"failed to get elements of `indices`");
|
||||
|
||||
auto input = op.getSelf();
|
||||
auto inputType = input.getType().template cast<BaseTensorType>();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only input with shape information is supported");
|
||||
|
@ -7693,15 +7670,13 @@ public:
|
|||
int64_t inputRank = inputSizes.size();
|
||||
|
||||
auto isTensor = [](Value v) {
|
||||
return v.getType().isa<Torch::BaseTensorType>();
|
||||
return isa<Torch::BaseTensorType>(v.getType());
|
||||
};
|
||||
|
||||
// directly replace current op with aten.index_put.hacked_twin
|
||||
if (llvm::all_of(indices, isTensor)) {
|
||||
// By default, we regard the first index type as the list element type.
|
||||
auto indexElemType = indices[0]
|
||||
.getType()
|
||||
.template cast<BaseTensorType>()
|
||||
auto indexElemType = cast<BaseTensorType>(indices[0].getType())
|
||||
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||
auto newIndex = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(indexElemType), indices);
|
||||
|
@ -7831,7 +7806,7 @@ public:
|
|||
|
||||
// default ord value is 2 for vector_norm
|
||||
auto ord = op.getOrd();
|
||||
if (ord.getType().isa<Torch::NoneType>()) {
|
||||
if (isa<Torch::NoneType>(ord.getType())) {
|
||||
ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(
|
||||
|
|
|
@ -63,8 +63,8 @@ public:
|
|||
};
|
||||
|
||||
static bool isTypeTriviallySafe(Type type) {
|
||||
return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>();
|
||||
return isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>(type);
|
||||
}
|
||||
|
||||
static bool isUseTreatedWithValueSemantics(OpOperand &use) {
|
||||
|
|
|
@ -36,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
static LogicalResult checkType(Operation *op, Type type,
|
||||
bool actuallyEmitDiagnostics) {
|
||||
// Allow various scalar types that backends are expected to be able to handle.
|
||||
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||
Torch::DeviceType>())
|
||||
if (isa<Torch::IntType, Torch::FloatType, Torch::BoolType, Torch::DeviceType>(
|
||||
type))
|
||||
return success();
|
||||
|
||||
// Backends are not expected to support dynamic computations on these types,
|
||||
|
|
|
@ -187,7 +187,7 @@ public:
|
|||
auto it = originalReturnTypes.find(i);
|
||||
if (it == originalReturnTypes.end())
|
||||
continue;
|
||||
auto originalType = it->second.cast<NonValueTensorType>();
|
||||
auto originalType = cast<NonValueTensorType>(it->second);
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(),
|
||||
originalType, operand.get());
|
||||
|
@ -350,7 +350,7 @@ public:
|
|||
auto it = originalTypes.find(operand.get());
|
||||
if (it == originalTypes.end())
|
||||
continue;
|
||||
auto originalType = it->second.cast<BaseTensorType>();
|
||||
auto originalType = cast<BaseTensorType>(it->second);
|
||||
rewriter.setInsertionPoint(op);
|
||||
Value newReturnValue = copyTensorToType(rewriter, op->getLoc(),
|
||||
originalType, operand.get());
|
||||
|
|
|
@ -118,7 +118,7 @@ public:
|
|||
if (auto optionalType =
|
||||
dyn_cast<OptionalType>(listType.getContainedType())) {
|
||||
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
|
||||
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
||||
return isa<NonValueTensorType, Torch::NoneType>(val.getType());
|
||||
})) {
|
||||
rewriter.cancelOpModification(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
|
|||
if (name.starts_with("valsem."))
|
||||
name = name.drop_front(strlen("valsem."));
|
||||
if (isa<OperatorOp>(op))
|
||||
name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
|
||||
name = cast<StringAttr>(cast<OperatorOp>(op)->getAttr("name")).getValue();
|
||||
std::string libFuncName =
|
||||
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
|
||||
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
|
||||
|
@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
|||
// to match the library function signature.
|
||||
if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
|
||||
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
||||
return containedType
|
||||
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
|
||||
return isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(
|
||||
containedType);
|
||||
}))
|
||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
|
|
@ -179,11 +179,10 @@ public:
|
|||
"should have concrete Scalar Type.");
|
||||
}
|
||||
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
|
||||
auto impliedTypeFromInputType =
|
||||
auto impliedTypeFromInputType = cast<BaseTensorType>(
|
||||
cast<BaseTensorType>(originalResultType)
|
||||
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
|
||||
inputType)
|
||||
.cast<BaseTensorType>();
|
||||
inputType));
|
||||
|
||||
op.getResult().setType(impliedTypeFromInputType);
|
||||
return success();
|
||||
|
|
|
@ -97,11 +97,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
|||
}
|
||||
|
||||
auto originalResultType = cast<BaseTensorType>(result.getType());
|
||||
auto impliedTypesFromShape =
|
||||
auto impliedTypesFromShape = cast<BaseTensorType>(
|
||||
cast<BaseTensorType>(originalResultType)
|
||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||
originalResultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
originalResultType.getOptionalDtype()));
|
||||
|
||||
return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
|
||||
rewriter);
|
||||
|
|
|
@ -74,7 +74,7 @@ LogicalResult FromBuiltinTensorOp::verify() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -87,7 +87,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -100,7 +100,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -113,7 +113,7 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -126,7 +126,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -139,7 +139,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
|
|
@ -91,7 +91,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
|||
return std::nullopt;
|
||||
// Other input type to be converted to i64 are handled by other
|
||||
// materializers.
|
||||
if (!inputs[0].getType().isa<Torch::IntType>())
|
||||
if (!isa<Torch::IntType>(inputs[0].getType()))
|
||||
return std::nullopt;
|
||||
assert(inputs.size() == 1);
|
||||
return builder.create<ToI64Op>(loc, inputs[0]).getResult();
|
||||
|
@ -145,7 +145,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
|||
return std::nullopt;
|
||||
// Other input type to be converted to i64 are handled by other
|
||||
// materializers.
|
||||
if (!inputs[0].getType().isa<Torch::GeneratorType>())
|
||||
if (!isa<Torch::GeneratorType>(inputs[0].getType()))
|
||||
return std::nullopt;
|
||||
assert(inputs.size() == 1);
|
||||
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
|
||||
|
|
|
@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
|
|||
static bool isArgMemRefTypeValid(Type type) {
|
||||
if (auto memRefType = dyn_cast<MemRefType>(type)) {
|
||||
Type elemTy = memRefType.getElementType();
|
||||
if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
|
||||
if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) {
|
||||
return true;
|
||||
} else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
|
||||
if (integerTy.isSignlessInteger(64))
|
||||
|
@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) {
|
|||
if (integerTy.isSignlessInteger(1))
|
||||
return true;
|
||||
} else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
|
||||
return complexTy.getElementType().isa<Float32Type, Float64Type>();
|
||||
return isa<Float32Type, Float64Type>(complexTy.getElementType());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue