import React, { useState } from 'react';
import { withAuthenticator } from '@aws-amplify/ui-react';
import { Amplify } from 'aws-amplify';
import { fetchAuthSession } from 'aws-amplify/auth';
import awsconfig from './aws-exports';
import '@aws-amplify/ui-react/styles.css';
import './SubmitChunks.css';
import { v4 as uuidv4 } from 'uuid';
import { createS3Client, getIdentityPoolUserId, sanitizeIdentityPoolUserId } from './chunkHelper'; 
import { S3Client, PutObjectCommand } from "@aws-sdk/client-s3"; 

// Configure the Amplify library once, at module load, with the settings
// exported by ./aws-exports. fetchAuthSession() calls below rely on this.
Amplify.configure(awsconfig);

function SubmitTFTask() {
  // Selected files: each stays null until its file input reports a selection.
  const [xsFile, setXsFile] = useState(null); // Input data (xs.csv)
  const [ysFile, setYsFile] = useState(null); // Output data (ys.csv)
  const [modelFile, setModelFile] = useState(null); // TensorFlow model (model.json)
  const [weightsFile, setWeightsFile] = useState(null); // Model weights (weights.bin)
  const [inputShape, setInputShape] = useState(''); // Free-text input shape, e.g. "[10000, 10, 1]"
  const [outputShape, setOutputShape] = useState(''); // Free-text output shape, same format as inputShape
  const [message, setMessage] = useState(''); // Status or error text rendered in the message box
  const [batchId, setBatchId] = useState(null); // UUID generated for the most recent upload attempt
  const [taskMetadata, setTaskMetadata] = useState({}); // User overrides keyed by field name (epochs, loss, optimizer, ...)
  const [isUploading, setIsUploading] = useState(false); // True while handleFileUpload is in progress; disables the button

  // Keep the first file picked in the xs.csv input (undefined when the list is empty).
  const handleXsFileChange = ({ target }) => {
    const { files } = target;
    setXsFile(files[0]);
  };

  // Keep the first file picked in the ys.csv input (undefined when the list is empty).
  const handleYsFileChange = ({ target }) => {
    const { files } = target;
    setYsFile(files[0]);
  };

  // Keep the first file picked in the model.json input (undefined when the list is empty).
  const handleModelFileChange = ({ target }) => {
    const { files } = target;
    setModelFile(files[0]);
  };

  // Keep the first file picked in the weights.bin input (undefined when the list is empty).
  const handleWeightsFileChange = ({ target }) => {
    const { files } = target;
    setWeightsFile(files[0]);
  };

  // Set one field of taskMetadata via a functional update, so the previous
  // state object is copied rather than mutated.
  const handleTaskMetadataChange = (field, value) => {
    const mergeField = (previous) => {
      const patch = { [field]: value };
      return { ...previous, ...patch };
    };
    setTaskMetadata(mergeField);
  };

  /**
   * Upload one file to the "my-data-chunks-bucket" bucket and tag it with the
   * current task settings as S3 user metadata.
   *
   * Object key layout:
   *   - type 'model' | 'weights': chunks/<owner>/<batch>/model/<file.name>
   *   - type 'xs' | 'ys':         chunks/<owner>/<batch>/<chunkId>/<type>/<file.name>
   *   - anything else:            chunks/<owner>/<batch>/<file.name>
   *
   * @param {object} s3Client - Client exposing `send(command)`.
   * @param {{name: string, type: string}} file - File to upload; also used as the request body.
   * @param {string} newBatchId - Batch identifier used in the key and metadata.
   * @param {string} owner - Owner identifier used in the key and metadata.
   * @param {string} type - File role ('model', 'weights', 'xs', 'ys', or other).
   * @param {?string} [chunkId=null] - Chunk identifier; only meaningful for 'xs'/'ys'.
   * @returns {Promise<string>} The object key the file was written to.
   * @throws Whatever `s3Client.send` rejects with.
   */
  const uploadFileToS3 = async (s3Client, file, newBatchId, owner, type, chunkId = null) => {
    const batchPrefix = `chunks/${owner}/${newBatchId}`;
    let fileKey;

    if (type === 'model' || type === 'weights') {
      fileKey = `${batchPrefix}/model/${file.name}`;
    } else if (type === 'xs' || type === 'ys') {
      fileKey = `${batchPrefix}/${chunkId}/${type}/${file.name}`;
    } else {
      fileKey = `${batchPrefix}/${file.name}`;
    }

    const rawMetadata = {
      'chunk-id': chunkId,
      'task-type': 'tensorflow',
      'batch-id': newBatchId,
      'owner': owner,
      'file-type': type,
      'input-shape': inputShape,
      'output-shape': outputShape,
      'epochs': taskMetadata.epochs || '10',
      'loss': taskMetadata.loss || 'meanSquaredError',
      'optimizer': taskMetadata.optimizer || 'adam'
    };

    // S3 user metadata values must be strings. Batch-level uploads have no
    // chunk id (null), so absent values are left out instead of being sent as
    // null, and everything else is coerced to a string.
    const metadata = Object.fromEntries(
      Object.entries(rawMetadata)
        .filter(([, value]) => value !== null && value !== undefined)
        .map(([name, value]) => [name, String(value)])
    );

    const putObjectParams = {
      Bucket: "my-data-chunks-bucket",
      Key: fileKey,
      Body: file,
      // file.type is '' when no MIME type is known; fall back to a generic binary type.
      ContentType: file.type || 'application/octet-stream',
      Metadata: metadata
    };

    await s3Client.send(new PutObjectCommand(putObjectParams));
    console.log(`${type} file uploaded to: ${fileKey}`);
    return fileKey;
  };

  /**
   * POST a batch description to the `/start` route of the dev API stage,
   * authenticated with the current session's ID token as a Bearer token.
   *
   * @returns {Promise<any>} The parsed JSON body of a successful response.
   * @throws {Error} When the response status is not OK; the message carries
   *   the status code and the response text.
   */
  const createRiddleSourceEntry = async (batchId, fileName, taskType, taskMetadata, owner, modelS3Location) => {
    const route = '/start';

    const { tokens } = await fetchAuthSession();
    const bearerToken = tokens.idToken.toString();

    const payload = { batchId, fileName, taskType, taskMetadata, owner, modelS3Location };
    const requestOptions = {
      method: 'POST',
      headers: {
        Authorization: `Bearer ${bearerToken}`,
        'Content-Type': 'application/json'
      },
      body: JSON.stringify(payload)
    };

    const response = await fetch(
      `https://0pawfsvt1a.execute-api.us-east-1.amazonaws.com/dev${route}`,
      requestOptions
    );

    if (response.ok) {
      console.log('Riddle Source entry created:', { batchId, taskType, taskMetadata, modelS3Location });
      return await response.json();
    }

    const errorText = await response.text();
    throw new Error(`HTTP error! status: ${response.status}, body: ${errorText}`);
  };

  /**
   * Click handler for the Upload button.
   *
   * Steps, in order:
   *   1. Bail out with a message unless all four files and both shapes are set.
   *   2. Resolve the owner id from the session's ID token and create a batch id.
   *   3. Upload model and weights at the batch level.
   *   4. Register the batch through createRiddleSourceEntry.
   *   5. Upload xs, then ys, under one freshly generated chunk id.
   *
   * Any failure is logged and shown in the message box; isUploading is always
   * reset at the end.
   */
  const handleFileUpload = async () => {
    const requiredValues = [xsFile, ysFile, modelFile, weightsFile, inputShape, outputShape];
    if (requiredValues.some((value) => !value)) {
      setMessage('Please select all required files and provide input/output shapes.');
      return;
    }

    setIsUploading(true);

    try {
      const { tokens } = await fetchAuthSession();
      const idToken = tokens.idToken.toString();
      const owner = sanitizeIdentityPoolUserId(await getIdentityPoolUserId(idToken));

      const newBatchId = uuidv4();
      setBatchId(newBatchId);

      const s3Client = createS3Client(idToken);

      // Batch-level artifacts: the model key is reused in the batch entry below.
      const modelS3Location = await uploadFileToS3(s3Client, modelFile, newBatchId, owner, 'model');
      await uploadFileToS3(s3Client, weightsFile, newBatchId, owner, 'weights');

      // Mirror the shapes into taskMetadata state.
      handleTaskMetadataChange('inputShape', inputShape);
      handleTaskMetadataChange('outputShape', outputShape);

      // Batch settings sent with the entry; unset fields fall back to defaults.
      const { numColumns, modelType, epochs, loss, optimizer } = taskMetadata;
      const batchSettings = {
        modelS3Location,
        inputShape,
        outputShape,
        numColumns: numColumns || 'undefined',
        modelType: modelType || 'neural-network',
        epochs: epochs || 10,
        loss: loss || 'meanSquaredError',
        optimizer: optimizer || 'adam'
      };

      await createRiddleSourceEntry(
        newBatchId,
        xsFile.name,
        'tensorflow',
        JSON.stringify(batchSettings),
        owner,
        modelS3Location
      );

      // xs and ys share one chunk id and are uploaded strictly one after the other.
      const chunkId = uuidv4();

      console.log('Uploading xs file...');
      const xsChunkKey = await uploadFileToS3(s3Client, xsFile, newBatchId, owner, 'xs', chunkId);
      console.log('XS file uploaded with key:', xsChunkKey);

      console.log('Uploading ys file...');
      const ysChunkKey = await uploadFileToS3(s3Client, ysFile, newBatchId, owner, 'ys', chunkId);
      console.log('YS file uploaded with key:', ysChunkKey);

      setMessage(`File upload initiated. Batch ID: ${newBatchId}`);
    } catch (error) {
      console.error('Error uploading files:', error);
      setMessage(`Error uploading files: ${error.message}`);
    } finally {
      setIsUploading(false);
    }
  };

  return (
    <div className="submit-task-page">
      <div className="submit-task-heading">
        <h2>Submit TensorFlow Task</h2>
      </div>
      <div className="submit-task-form">
        <div className="form-group centered">
          <label htmlFor="xs-file-upload">Upload Input Data (xs.csv)</label>
          <input
            type="file"
            id="xs-file-upload"
            accept=".csv"
            onChange={handleXsFileChange}
            required
          />
        </div>

        <div className="form-group centered">
          <label htmlFor="ys-file-upload">Upload Output Data (ys.csv)</label>
          <input
            type="file"
            id="ys-file-upload"
            accept=".csv"
            onChange={handleYsFileChange}
            required
          />
        </div>

        <div className="form-group centered">
          <label htmlFor="model-upload">Upload TensorFlow Model (model.json)</label>
          <input
            type="file"
            id="model-upload"
            accept=".json"
            onChange={handleModelFileChange}
            required
          />
        </div>

        <div className="form-group centered">
          <label htmlFor="weights-upload">Upload Weights (weights.bin)</label>
          <input
            type="file"
            id="weights-upload"
            accept=".bin"
            onChange={handleWeightsFileChange}
            required
          />
        </div>

        <div className="form-group">
          <label htmlFor="input-shape">Input Shape</label>
          <input
            type="text"
            id="input-shape"
            value={inputShape}
            onChange={(e) => setInputShape(e.target.value)}
            required
            placeholder="e.g., [10000, 10, 1]"
          />
        </div>

        <div className="form-group">
          <label htmlFor="output-shape">Output Shape</label>
          <input
            type="text"
            id="output-shape"
            value={outputShape}
            onChange={(e) => setOutputShape(e.target.value)}
            required
            placeholder="e.g., [10000, 10, 1]"
          />
        </div>

        <div className="form-group">
          <label htmlFor="epochs">Epochs</label>
          <input
            type="number"
            id="epochs"
            value={taskMetadata.epochs || 10}
            onChange={(e) => handleTaskMetadataChange('epochs', e.target.value)}
            required
          />
        </div>

        <div className="form-group">
          <label htmlFor="loss">Loss Function</label>
          <input
            type="text"
            id="loss"
            value={taskMetadata.loss || 'meanSquaredError'}
            onChange={(e) => handleTaskMetadataChange('loss', e.target.value)}
            required
          />
        </div>

        <div className="form-group">
          <label htmlFor="optimizer">Optimizer</label>
          <input
            type="text"
            id="optimizer"
            value={taskMetadata.optimizer || 'adam'}
            onChange={(e) => handleTaskMetadataChange('optimizer', e.target.value)}
            required
          />
        </div>

        <div className="form-group centered">
          <button onClick={handleFileUpload} disabled={isUploading}>
            {isUploading ? 'Uploading...' : 'Upload'}
          </button>
        </div>

        {message && (
          <div className="message-box">
            <div>{message}</div>
            {batchId && (
              <>
                <br />
                <div>Batch ID: {batchId}</div>
              </>
            )}
          </div>
        )}
      </div>
    </div>
  );
}

// Default export: SubmitTFTask wrapped in the withAuthenticator higher-order
// component from @aws-amplify/ui-react. The upload handlers read the ID token
// from fetchAuthSession(), so they presumably expect a signed-in user.
export default withAuthenticator(SubmitTFTask);
