package org.eclipse.january.dataset;

import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link BroadcastIterator} and {@link BroadcastUtils#calculateBroadcastShapes}.
 *
 * <p>Covers broadcast-shape calculation for a range of source/target shape combinations,
 * iteration over two broadcast inputs without an output dataset, and iteration with an explicit
 * output dataset (separate, aliasing an input, a strided view, or compound).
 */
public class BroadcastIteratorTest {
    /** Absolute tolerance used when comparing double values read back from datasets. */
    private static final double ABS_TOL = 1e-15;

    /**
     * Integer dtype code passed to the int-based {@code DatasetFactory} methods, kept as the
     * literal the test has always used. NOTE(review): presumably {@code Dataset.FLOAT64} — confirm.
     */
    private static final int DTYPE_FLOAT64 = 6;

    @Test
    public void testBroadcastShape() {
        Dataset scalar = DatasetFactory.ones(DoubleDataset.class, new int[0]);
        checkBroadcastShape(scalar, "scalar as scalar", new int[0], new int[0], new int[0]);
        checkBroadcastShape(scalar, "scalar as [1]", new int[] {1}, new int[] {1}, 1);
        checkBroadcastShape(scalar, "scalar as [3]", new int[] {1}, new int[] {3}, 3);
        checkBroadcastShape(scalar, "scalar as [3,4]", new int[] {1, 1}, new int[] {3, 4}, 3, 4);

        Dataset single = DatasetFactory.ones(DoubleDataset.class, new int[] {1});
        checkBroadcastShape(single, "[1] as scalar", new int[] {1}, new int[0], new int[0]);
        checkBroadcastShape(single, "[1] as [1]", new int[] {1}, new int[] {1}, 1);
        checkBroadcastShape(single, "[1] as [3]", new int[] {1}, new int[] {3}, 3);
        checkBroadcastShape(single, "[1] as [3,4]", new int[] {1, 1}, new int[] {3, 4}, 3, 4);

        Dataset single2d = DatasetFactory.ones(DoubleDataset.class, new int[] {1, 1});
        checkBroadcastShape(single2d, "[1,1] as scalar", new int[] {1, 1}, new int[0], new int[0]);
        checkBroadcastShape(single2d, "[1,1] as [1]", new int[] {1, 1}, new int[] {1, 1}, 1);
        checkBroadcastShape(single2d, "[1,1] as [3]", new int[] {1, 1}, new int[] {1, 3}, 3);
        checkBroadcastShape(single2d, "[1,1] as [1,3]", new int[] {1, 1}, new int[] {1, 3}, 1, 3);
        checkBroadcastShape(single2d, "[1,1] as [3,4]", new int[] {1, 1}, new int[] {3, 4}, 3, 4);

        Dataset vector3 = DatasetFactory.ones(DoubleDataset.class, new int[] {3});
        checkBroadcastShape(vector3, "[3] as scalar", null, null, new int[0]);
        checkBroadcastShape(vector3, "[3] as [1]", new int[] {3}, new int[] {1}, 1);
        checkBroadcastShape(vector3, "[3] as [3]", new int[] {3}, new int[] {3}, 3);
        checkBroadcastShape(vector3, "[3] as [1,3]", new int[] {1, 3}, new int[] {1, 3}, 1, 3);
        checkBroadcastShape(vector3, "[3] as [3,4]", null, null, 3, 4);

        Dataset column3 = DatasetFactory.ones(DoubleDataset.class, new int[] {3, 1});
        checkBroadcastShape(column3, "[3,1] as scalar", null, null, new int[0]);
        checkBroadcastShape(column3, "[3,1] as [1]", new int[] {3, 1}, new int[] {1, 1}, 1);
        checkBroadcastShape(column3, "[3,1] as [3]", new int[] {3, 1}, new int[] {1, 3}, 3);
        checkBroadcastShape(column3, "[3,1] as [1,3]", new int[] {3, 1}, new int[] {1, 3}, 1, 3);
        checkBroadcastShape(column3, "[3,1] as [3,4]", new int[] {3, 1}, new int[] {3, 4}, 3, 4);
        checkBroadcastShape(column3, "[3,1] as [6,3,4]", new int[] {1, 3, 1}, new int[] {6, 3, 4}, 6, 3, 4);
        checkBroadcastShape(column3, "[3,1] as [3,4,6]", null, null, 3, 4, 6);

        Dataset row3 = DatasetFactory.ones(DoubleDataset.class, new int[] {1, 3});
        checkBroadcastShape(row3, "[1,3] as scalar", null, null, new int[0]);
        checkBroadcastShape(row3, "[1,3] as [1]", new int[] {1, 3}, new int[] {1, 1}, 1);
        checkBroadcastShape(row3, "[1,3] as [3]", new int[] {1, 3}, new int[] {1, 3}, 3);
        checkBroadcastShape(row3, "[1,3] as [1,3]", new int[] {1, 3}, new int[] {1, 3}, 1, 3);
        checkBroadcastShape(row3, "[1,3] as [3,4]", null, null, 3, 4);
        checkBroadcastShape(row3, "[1,3] as [4,3]", new int[] {1, 3}, new int[] {4, 3}, 4, 3);
        checkBroadcastShape(row3, "[1,3] as [6,4,3]", new int[] {1, 1, 3}, new int[] {6, 4, 3}, 6, 4, 3);
        checkBroadcastShape(row3, "[1,3] as [3,4,6]", null, null, 3, 4, 6);
    }

    /**
     * Asserts the pair of shapes returned by {@link BroadcastUtils#calculateBroadcastShapes} when
     * broadcasting {@code dataset} to {@code targetShape}.
     *
     * @param dataset source dataset whose shape and size are passed to the calculation
     * @param description label appended to the assertion message
     * @param expectedSourceShape expected first returned shape (the source shape in broadcast form)
     * @param expectedTargetShape expected second returned shape (the target shape in broadcast form)
     * @param targetShape shape to broadcast to
     */
    private void checkBroadcastShape(Dataset dataset, String description, int[] expectedSourceShape,
            int[] expectedTargetShape, int... targetShape) {
        // Passing both expectations as null means the calculation itself should return null
        // (the test cases use this for incompatible shapes).
        int[][] expected = (expectedSourceShape == null && expectedTargetShape == null)
                ? null
                : new int[][] {expectedSourceShape, expectedTargetShape};
        int[][] actual = BroadcastUtils.calculateBroadcastShapes(dataset.getShapeRef(), dataset.getSize(), targetShape);
        Assert.assertArrayEquals("Broadcasting " + description, expected, actual);
    }

    @Test
    public void testBroadcastWithNoOutput() {
        // [5,1] column against [1,6] row
        Dataset a = DatasetFactory.createRange(5.0, DTYPE_FLOAT64).reshape(new int[] {5, 1});
        Dataset b = DatasetFactory.createRange(2.0, 8.0, 1.0, DTYPE_FLOAT64).reshape(new int[] {1, 6});
        BroadcastIterator it = BroadcastIterator.createIterator(a, b);
        Assert.assertArrayEquals("Broadcast shape", new int[] {5, 6}, it.getShape());
        Dataset c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 6; j++) {
                Assert.assertTrue(it.hasNext());
                c.set(Double.valueOf(it.aDouble * it.bDouble), i, j);
                Assert.assertEquals(a.getDouble(i, 0), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(0, j), it.bDouble, ABS_TOL);
                Assert.assertEquals(i * (j + 2.0), c.getDouble(i, j), ABS_TOL);
            }
        }

        // [5,1] column against 1D [6]
        b.setShape(new int[] {6});
        it = BroadcastIterator.createIterator(a, b);
        Assert.assertArrayEquals("Broadcast shape", new int[] {5, 6}, it.getShape());
        c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 6; j++) {
                Assert.assertTrue(it.hasNext());
                c.set(Double.valueOf(it.aDouble * it.bDouble), i, j);
                Assert.assertEquals(a.getDouble(i, 0), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(j), it.bDouble, ABS_TOL);
                Assert.assertEquals(i * (j + 2.0), c.getDouble(i, j), ABS_TOL);
            }
        }

        // [1] against [6]
        Dataset ones = DatasetFactory.ones(new int[] {1}, DTYPE_FLOAT64);
        it = BroadcastIterator.createIterator(ones, b);
        Assert.assertArrayEquals("Broadcast shape", new int[] {6}, it.getShape());
        c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        for (int j = 0; j < 6; j++) {
            Assert.assertTrue(it.hasNext());
            c.set(Double.valueOf(it.aDouble * it.bDouble), j);
            Assert.assertEquals(ones.getDouble(0), it.aDouble, ABS_TOL);
            Assert.assertEquals(b.getDouble(j), it.bDouble, ABS_TOL);
            Assert.assertEquals(j + 2.0, c.getDouble(j), ABS_TOL);
        }

        // zero-rank scalar against [6]
        Dataset scalarOne = DatasetFactory.createFromObject(1);
        it = BroadcastIterator.createIterator(scalarOne, b);
        Assert.assertArrayEquals("Broadcast shape", new int[] {6}, it.getShape());
        c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        for (int j = 0; j < 6; j++) {
            Assert.assertTrue(it.hasNext());
            c.set(Double.valueOf(it.aDouble * it.bDouble), j);
            Assert.assertEquals(scalarOne.getDouble(), it.aDouble, ABS_TOL);
            Assert.assertEquals(b.getDouble(j), it.bDouble, ABS_TOL);
            Assert.assertEquals(j + 2.0, c.getDouble(j), ABS_TOL);
        }

        // zero-rank scalar against zero-rank scalar
        Dataset scalarTwo = DatasetFactory.createFromObject(2);
        it = BroadcastIterator.createIterator(scalarOne, scalarTwo);
        it.setOutputDouble(true);
        Assert.assertArrayEquals("Broadcast shape", new int[0], it.getShape());
        c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        Assert.assertTrue(it.hasNext());
        c.set(Double.valueOf(it.aDouble * it.bDouble));
        Assert.assertEquals(scalarOne.getDouble(), it.aDouble, ABS_TOL);
        Assert.assertEquals(scalarTwo.getDouble(), it.bDouble, ABS_TOL);
        Assert.assertEquals(2.0, c.getDouble(), ABS_TOL);

        // [5,1] column against a strided view taking every second element of [6] (shape [3])
        a = DatasetFactory.createRange(5.0, DTYPE_FLOAT64).reshape(new int[] {5, 1});
        Dataset strided = DatasetFactory.createRange(2.0, 8.0, 1.0, DTYPE_FLOAT64)
                .getSliceView(new Slice[] {new Slice((Integer) null, (Integer) null, 2)});
        it = BroadcastIterator.createIterator(a, strided);
        Assert.assertArrayEquals("Broadcast shape", new int[] {5, 3}, it.getShape());
        c = DatasetFactory.zeros(it.getShape(), DTYPE_FLOAT64);
        for (int i = 0; i < 5; i++) {
            for (int j = 0; j < 3; j++) {
                Assert.assertTrue(it.hasNext());
                c.set(Double.valueOf(it.aDouble * it.bDouble), i, j);
                Assert.assertEquals(a.getDouble(i, 0), it.aDouble, ABS_TOL);
                Assert.assertEquals(strided.getDouble(j), it.bDouble, ABS_TOL);
                Assert.assertEquals(i * (2 * j + 2.0), c.getDouble(i, j), ABS_TOL);
            }
        }
    }

    @Test
    public void testBroadcastWithOutput() {
        // [10,1] column times [1,12] row into a separate [10,12] output
        Dataset a = DatasetFactory.createRange(10.0, DTYPE_FLOAT64).reshape(new int[] {10, 1});
        Dataset b = DatasetFactory.createRange(2.0, 14.0, 1.0, DTYPE_FLOAT64).reshape(new int[] {1, 12});
        Dataset c = DatasetFactory.zeros(new int[] {10, 12}, DTYPE_FLOAT64);
        BroadcastIterator it = BroadcastIterator.createIterator(a, b, c);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(i, 0), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(0, j), it.bDouble, ABS_TOL);
                c.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals(i * (j + 2.0), c.getDouble(i, j), ABS_TOL);
            }
        }

        // output is the first input itself, updated in place
        a = DatasetFactory.createRange(120.0, DTYPE_FLOAT64).reshape(new int[] {10, 12});
        b = DatasetFactory.createRange(2.0, 14.0, 1.0, DTYPE_FLOAT64).reshape(new int[] {1, 12});
        it = BroadcastIterator.createIterator(a, b, a);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(i, j), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(0, j), it.bDouble, ABS_TOL);
                a.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals((i * 12 + j) * (j + 2.0), a.getDouble(i, j), ABS_TOL);
            }
        }

        // output is the second input itself, updated in place
        a = DatasetFactory.createRange(10.0, DTYPE_FLOAT64).reshape(new int[] {10, 1});
        b = DatasetFactory.createRange(2.0, 122.0, 1.0, DTYPE_FLOAT64).reshape(new int[] {10, 12});
        it = BroadcastIterator.createIterator(a, b, b);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(i, 0), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(i, j), it.bDouble, ABS_TOL);
                b.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals(i * (i * 12 + j + 2.0), b.getDouble(i, j), ABS_TOL);
            }
        }

        // strided view (every second row of a [20,12] dataset) as both first input and output
        a = DatasetFactory.createRange(240.0, DTYPE_FLOAT64).reshape(new int[] {20, 12})
                .getSliceView(new Slice[] {new Slice((Integer) null, (Integer) null, 2)});
        b = DatasetFactory.createRange(2.0, 14.0, 1.0, DTYPE_FLOAT64).reshape(new int[] {1, 12});
        it = BroadcastIterator.createIterator(a, b, a);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(i, j), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(0, j), it.bDouble, ABS_TOL);
                a.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals((24 * i + j) * (j + 2.0), a.getDouble(i, j), ABS_TOL);
            }
        }

        // 1D [12] inputs broadcast across the rows of a [10,12] output
        a = DatasetFactory.createRange(12.0, DTYPE_FLOAT64);
        b = DatasetFactory.createRange(2.0, 14.0, 1.0, DTYPE_FLOAT64);
        c = DatasetFactory.zeros(new int[] {10, 12}, DTYPE_FLOAT64);
        it = BroadcastIterator.createIterator(a, b, c);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(j), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(j), it.bDouble, ABS_TOL);
                c.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals(j * (j + 2.0), c.getDouble(i, j), ABS_TOL);
            }
        }

        // same 1D inputs into a compound output with 3 elements per item
        c = DatasetFactory.zeros(3, new int[] {10, 12}, DTYPE_FLOAT64);
        it = BroadcastIterator.createIterator(a, b, c);
        Assert.assertArrayEquals("Broadcast shape", new int[] {10, 12}, it.getShape());
        // explicit cast: the factory's declared return type may be the Dataset interface
        CompoundDataset compound = (CompoundDataset) c;
        int itemSize = c.getElementsPerItem();
        double[] item = new double[itemSize];
        for (int i = 0; i < 10; i++) {
            for (int j = 0; j < 12; j++) {
                Assert.assertTrue(it.hasNext());
                Assert.assertEquals(a.getDouble(j), it.aDouble, ABS_TOL);
                Assert.assertEquals(b.getDouble(j), it.bDouble, ABS_TOL);
                c.setObjectAbs(it.oIndex, Double.valueOf(it.aDouble * it.bDouble));
                Assert.assertEquals(j * (j + 2.0), c.getDouble(i, j), ABS_TOL);
                // every element of the compound item is expected to hold the same value
                compound.getDoubleArray(item, i, j);
                for (int k = 1; k < itemSize; k++) {
                    Assert.assertEquals(item[0], item[k], ABS_TOL);
                }
            }
        }
    }
}
