mirror of https://github.com/llvm/torch-mlir
[MLIR][Torch] Resolve styling issues related to aten zeros/ones op
https://github.com/llvm/torch-mlir/pull/464#discussion_r765065092 Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/470/head
parent
f34eb66124
commit
0a0a1b4476
|
@ -3329,41 +3329,43 @@ struct ConvertAtenOnesZerosOp : ConversionPattern {
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
SmallVector<Value, 3> opArguments;
|
Value size, layout, pin_memory;
|
||||||
int64_t elementValue;
|
int64_t elementValue;
|
||||||
|
|
||||||
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
|
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
|
||||||
opArguments.insert(opArguments.end(),
|
size = onesOp.size();
|
||||||
{onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
|
layout = onesOp.layout();
|
||||||
|
pin_memory = onesOp.pin_memory();
|
||||||
elementValue = 1;
|
elementValue = 1;
|
||||||
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
|
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
|
||||||
opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
|
size = zerosOp.size();
|
||||||
zerosOp.pin_memory()});
|
layout = zerosOp.layout();
|
||||||
|
pin_memory = zerosOp.pin_memory();
|
||||||
elementValue = 0;
|
elementValue = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We ignore device, but add simple asserts for unimplemented kwargs
|
// We ignore device, but add simple asserts for unimplemented kwargs
|
||||||
if (!opArguments[1].getType().isa<Torch::NoneType>())
|
if (!layout.getType().isa<Torch::NoneType>())
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only default layout is supported");
|
"only default layout is supported");
|
||||||
|
|
||||||
bool pinMemory = false;
|
bool pinMemory = false;
|
||||||
if (!opArguments[2].getType().isa<Torch::NoneType>() &&
|
if (!pin_memory.getType().isa<Torch::NoneType>() &&
|
||||||
!matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
|
!matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "pin_memory must be constant bool or None");
|
op, "pin_memory must be constant bool or None");
|
||||||
}
|
}
|
||||||
if (pinMemory)
|
if (pinMemory)
|
||||||
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
|
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
|
||||||
|
|
||||||
SmallVector<Value> size, sizeIndex;
|
SmallVector<Value> sizes, sizeIndex;
|
||||||
if (!getListConstructElements(opArguments[0], size)) {
|
if (!getListConstructElements(size, sizes)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "size must be created by ListConstruct");
|
op, "size must be created by ListConstruct");
|
||||||
}
|
}
|
||||||
size = getTypeConvertedValues(rewriter, loc, getTypeConverter(), size);
|
sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
|
||||||
for (size_t i = 0; i < size.size(); i++)
|
for (size_t i = 0; i < sizes.size(); i++)
|
||||||
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
|
sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
|
||||||
|
|
||||||
RankedTensorType newResultType =
|
RankedTensorType newResultType =
|
||||||
getTypeConverter()
|
getTypeConverter()
|
||||||
|
|
Loading…
Reference in New Issue